{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9cd05212-388f-4702-a1bd-d4db19afc8fa",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Quantum Generative Adversarial Networks (QGANs)\n",
    "***\n",
    "Generative Adversarial Networks (GANs) are generative models built from two competing networks: a generator that produces synthetic samples and a discriminator that learns to distinguish them from real data. GANs have been used to produce highly realistic images, audio, and other content. Quantum GANs replace the generator, the discriminator, or both with parameterized quantum circuits, with the aim of using quantum computing to model complex data distributions.\n",
    "***\n",
    "In this notebook, we'll explore the concept of Quantum Generative Adversarial Networks (QGANs) and implement a simple QGAN model using the Classiq SDK.\n",
    "\n",
    "We study a simple use case: the Bars and Stripes dataset. We begin with a classical GAN implementation and then move to a hybrid quantum-classical GAN model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbc47d04-891a-4ec2-ad39-5b146622fc75",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 1 Data Preparation\n",
    "\n",
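    "We generate the Bars and Stripes dataset, a simple binary dataset of 2x2 images in which either the two columns are identical (horizontal stripes) or the two rows are identical (vertical stripes). Each sample is drawn by picking an orientation at random and then a random bit for each stripe."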
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90eed60c6bc07ca5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:01:58.670187Z",
     "start_time": "2024-02-21T14:01:58.613757Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:36.737287Z",
     "iopub.status.busy": "2024-05-07T14:48:36.735197Z",
     "iopub.status.idle": "2024-05-07T14:48:36.833968Z",
     "shell.execute_reply": "2024-05-07T14:48:36.833197Z"
    }
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Function to create Bars and Stripes dataset\n",
    "def create_bars_and_stripes_dataset(num_samples):\n",
    "    samples = []\n",
    "    for _ in range(num_samples):\n",
    "        horizontal = np.random.randint(0, 2) == 0\n",
    "        if horizontal:\n",
    "            stripe = np.random.randint(0, 2, size=(2, 1))\n",
    "            sample = np.tile(stripe, (1, 2))\n",
    "        else:\n",
    "            stripe = np.random.randint(0, 2, size=(1, 2))\n",
    "            sample = np.tile(stripe, (2, 1))\n",
    "        samples.append(sample)\n",
    "    return np.array(samples, dtype=np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3306f1cb17fc77af",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:01:58.685166Z",
     "start_time": "2024-02-21T14:01:58.676420Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:36.838762Z",
     "iopub.status.busy": "2024-05-07T14:48:36.837575Z",
     "iopub.status.idle": "2024-05-07T14:48:36.887924Z",
     "shell.execute_reply": "2024-05-07T14:48:36.887141Z"
    }
   },
   "outputs": [],
   "source": [
    "# Generate Bars and Stripes dataset\n",
    "dataset = create_bars_and_stripes_dataset(num_samples=1000)"
   ]
  },
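  {
   "cell_type": "markdown",
   "id": "check-bars-and-stripes-patterns",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "As a quick sanity check (a sketch added here, not needed for training), we can count the valid patterns. A 2x2 image belongs to the Bars and Stripes family when its two rows are identical (vertical stripes) or its two columns are identical (horizontal stripes). Of the 16 possible binary 2x2 images, exactly 6 qualify: the all-zero and all-one images plus two vertical and two horizontal patterns. The helper `is_bars_or_stripes` is hypothetical and introduced only for this check; applied to the generated samples, `all(is_bars_or_stripes(s) for s in dataset)` should return `True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "check-bars-and-stripes-patterns-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "# Hypothetical helper: a 2x2 image is a valid pattern if its rows are identical\n",
    "# (vertical stripes) or its columns are identical (horizontal stripes)\n",
    "def is_bars_or_stripes(image):\n",
    "    rows_equal = np.array_equal(image[0, :], image[1, :])\n",
    "    cols_equal = np.array_equal(image[:, 0], image[:, 1])\n",
    "    return rows_equal or cols_equal\n",
    "\n",
    "\n",
    "# Enumerate all 16 binary 2x2 images and keep the valid ones\n",
    "all_images = [np.array(bits).reshape(2, 2) for bits in np.ndindex(2, 2, 2, 2)]\n",
    "valid_patterns = [img for img in all_images if is_bars_or_stripes(img)]\n",
    "print(len(valid_patterns))  # 6 of the 16 possible images"
   ]
  },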
  {
   "cell_type": "markdown",
   "id": "a8da3493f26da0",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 1.1 Visualizing the generated data \n",
    "Let's plot nine freshly generated samples to visualize the Bars and Stripes patterns:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d352b31c-c63e-4c69-99b4-aa6bb5c8e56d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-28T12:04:38.504739Z",
     "start_time": "2024-02-28T12:04:38.351180Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:36.893260Z",
     "iopub.status.busy": "2024-05-07T14:48:36.891930Z",
     "iopub.status.idle": "2024-05-07T14:48:37.394156Z",
     "shell.execute_reply": "2024-05-07T14:48:37.393361Z"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "\n",
    "\n",
    "# Plot images in a 3 by 3 grid\n",
    "def plot_nine_images(generated_images):\n",
    "    # Custom colormap: 0 maps to green, 1 maps to black\n",
    "    classiq_cmap = LinearSegmentedColormap.from_list(\n",
    "        \"green_black\", [(0, \"#00FF00\"), (1, \"black\")]\n",
    "    )\n",
    "    fig, axes = plt.subplots(3, 3, figsize=(6, 6))\n",
    "    for i, ax in enumerate(axes.flat):\n",
    "        ax.imshow(generated_images[i].reshape(2, 2), cmap=classiq_cmap, vmin=0, vmax=1)\n",
    "        ax.axis(\"off\")\n",
    "        ax.set_title(f\"Image {i+1}\")\n",
    "        for j in range(2):\n",
    "            for k in range(2):\n",
    "                label = int(generated_images[i].reshape(2, 2)[j, k])\n",
    "                ax.text(\n",
    "                    k,\n",
    "                    j,\n",
    "                    f\"{label}\",\n",
    "                    ha=\"center\",\n",
    "                    va=\"center\",\n",
    "                    color=\"white\" if label == 1 else \"black\",\n",
    "                    fontsize=16,\n",
    "                )\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c7dacda8-62ab-48cc-b2c3-ff94f9f006ee",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-28T12:04:38.504739Z",
     "start_time": "2024-02-28T12:04:38.351180Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:37.399070Z",
     "iopub.status.busy": "2024-05-07T14:48:37.397890Z",
     "iopub.status.idle": "2024-05-07T14:48:38.190650Z",
     "shell.execute_reply": "2024-05-07T14:48:38.189926Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAJOCAYAAABLBSanAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA12UlEQVR4nO3df3RU9Z3/8deAEkKG8Bux4Uc0SljRShC1B6lAkEI8lBVULGzlh+iCbjBwKtVD1vItdLMKqNCD6K5okBpEsQeUWmVFwF9YBVHAdTliSSHGwpHIzxg0TT7fPzLMMRDiZCS5977n+bhnzp3cuQPvvD/znryYDEnIOecEAABgRDOvCwAAADibCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFPPhZtmyZQqFQtq6davXpTSqxx57TDfffLO6d++uUCikiRMnel0SfCYRZqGkpES//e1vddVVV6ldu3bq2LGjBg0apPXr13tdGnwiEeagoqJCkydP1qWXXqo2bdooHA7r8ssv16JFi1RZWel1eU3iHK8LwNnx4IMP6tixY7rqqqv097//3etyAE+8+OKLevDBB3XDDTdowoQJ+sc//qHly5dr6NCheuqppzRp0iSvSwQaXUVFhf73f/9X119/vdLT09WsWTNt3rxZM2bM0HvvvacVK1Z4XWKjI9wY8cYbb0RftQmHw16XA3hi8ODB2rdvnzp27Bg9NnXqVPXp00e/+c1vCDdICO3bt9df/vKXWsemTp2qNm3aaPHixXr44YfVpUsXj6prGua/LVWXiRMnKhwOa9++fRoxYoTC4bDS0tL06KOPSpJ27typ7OxspaSkqEePHqel3K+++kr33HOPLrvsMoXDYaWmpionJ0fbt28/7e/au3evRo4cqZSUFHXu3FkzZszQunXrFAqFtGnTplrnvvfeexo+fLjatGmjVq1aaeDAgXrnnXdi+px69OihUCgUX0OQsKzNQu/evWsFG0lKSkrS9ddfr88//1zHjh1rYIeQCKzNwZmkp6dLkg4fPhz3nxEUCRluJKmqqko5OTnq1q2b5s2bp/T0dOXm5mrZsmUaPny4+vXrpwcffFCtW7fW+PHjVVxcHL3vnj17tGbNGo0YMUIPP/ywZs6cqZ07d2rgwIH64osvoueVl5crOztb69ev19133638/Hxt3rxZ995772n1bNiwQddee62OHj2q2bNnq6CgQIcPH1Z2drbef//9JukJElMizML+/fvVqlUrtWrVKq77wz6Lc/Dtt9/q4MGDKikp0erVq7VgwQL16NFDF1100Q9vmN854woLC50kt2XLluixCRMmOEmuoKAgeuzQoUMuOTnZhUIht3LlyujxXbt2OUlu9uzZ0WMnTpxwVVVVtf6e4uJil5SU5ObMmRM99tBDDzlJbs2aNdFjFRUVrlevXk6S27hxo3POuerqanfxxRe7YcOGuerq6ui5X3/9tbvgggvc0KFDG/Q5p6SkuAkTJjToPrAvEWfBOed2797tWrZs6W699dYG3xf2JNIcPPvss05S9NKvXz+3Y8eOmO4bdAn7yo0k3X777dHrbdu2VWZmplJSUjRmzJjo8czMTLVt21Z79uyJHktKSlKzZjWtq6qqUllZmcLhsDIzM7Vt27boea+++qrS0tI0cuTI6LGWLVvqjjvuqFXHRx99pN27d2vcuHEqKyvTwYMHdfDgQZWXl2vIkCF68803VV1dfdY/f+Akq7Pw9ddf6+abb1ZycrIeeOCB2BuChGRtDgYPHqzXXntNq1at0tSpU3XuueeqvLy84Y0JoIR9Q3HLli3VqVOnWsfatGmjrl27nvbelTZt2ujQoUPRj6urq7Vo0SItWbJExcXFqqqqit7WoUOH6PW9e/cqIyPjtD/v1JcEd+/eLUmaMGHCGes9cuSI2rVrF+NnB8TO6ixUVVXpF7/4hT755BO9
8sor+tGPfvS990HisjgH5513ns477zxJ0k033aSCggINHTpUu3fvNv+G4oQNN82bN2/Qcedc9HpBQYHuv/9+3XbbbZo7d67at2+vZs2aafr06XG9wnLyPvPnz1efPn3qPIf/AYXGYnUW7rjjDv3pT39SUVGRsrOzG1wLEovVOfium266Sfn5+XrxxRc1ZcqUBt8/SBI23PwQL7zwggYPHqwnn3yy1vHDhw/X+p8aPXr00CeffCLnXK2k/tlnn9W6X0ZGhiQpNTVV1113XSNWDpxdfp2FmTNnqrCwUAsXLtTYsWPj/nOAWPh1Dk5VUVEhqeZVH+sS+j038WrevHmt1C5Jq1atUmlpaa1jw4YNU2lpqV566aXosRMnTuiJJ56odd4VV1yhjIwMLViwQMePHz/t7/vyyy/PYvXA2ePHWZg/f74WLFigWbNmKS8vryGfDhAXv83BwYMHT6tHkpYuXSpJ6tevX/2fkAG8chOHESNGaM6cOZo0aZL69++vnTt3qqioSBdeeGGt86ZMmaLFixdr7NixysvL0/nnn6+ioiK1bNlSkqLJvVmzZlq6dKlycnLUu3dvTZo0SWlpaSotLdXGjRuVmpqqtWvX1lvT2rVroz9TobKyUjt27NDvfvc7SdLIkSP14x//+Gy3AfDdLKxevVq//vWvdfHFF+uf/umf9Mwzz9S6fejQodH3IABni9/m4JlnntHjjz+uG264QRdeeKGOHTumdevW6bXXXtPPf/7zhPg2LeEmDrNmzVJ5eblWrFih5557Tn379tXLL7+s++67r9Z54XBYGzZs0LRp07Ro0SKFw2GNHz9e/fv314033hh9QEvSoEGD9O6772ru3LlavHixjh8/ri5duujqq6+O6Xujf/zjH/X0009HP/7www/14YcfSpK6du1KuEGj8NssnAz4u3fv1q233nra7Rs3biTc4Kzz2xwMGDBAmzdv1rPPPqsDBw7onHPOUWZmph5++GFNmzatUXrgNyFX12tXaFQLFy7UjBkz9PnnnystLc3rcgDPMAsAc9AYCDeNrKKiQsnJydGPT5w4oaysLFVVVenTTz/1sDKgaTELAHPQVPi2VCMbPXq0unfvrj59+ujIkSN65plntGvXLhUVFXldGtCkmAWAOWgqhJtGNmzYMC1dulRFRUWqqqrSJZdcopUrV+qWW27xujSgSTELAHPQVPi2FAAAMIWfcwMAAEwh3AAAAFMINwAAwJSY31B86m8xRQPwrqa4OR82LyRmAU3Pb7PAHMALsc4Br9wAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMCGW569uyp3NxcFRYWaseOHaqsrJRzTvn5+V6XFhyrJA2S1E5SiqTLJc2TVOlhTYgPaxk/emcL6xk/a71zMZLkm8sjjzxSZ435+fme11bnxW9bXqSuc+T0MzmNllPbyLEBcvra8wqjmx953ZNaW54Cs5a+2/IUqN75jdf9OG3LU6DW01dbngLTu9gfn7Ge6HVA+M5l8uTJbt68eW7s2LEuMzPTPf300845wk1M2+pITWE5ffCd41/K6bLIbb/yvMro5kde9yS6rVag1tJX22oFrnd+43U/am2rFbj19M22WoHqXeyPz1hP9Dog1HMpLCx0zhFuYtqujNT0uzpueytyW5KcDnteqZPz3xO6cz56Ug/YWvpqC2Dv/MbrftTaArievtkC1rtYBfI9N4hTqaQtkevj6rh9gKRukr6R9OemKgpxYS3jR+9sYT3jZ7h3hJtE8mFk317SBWc4p98p58KfWMv40TtbWM/4Ge4d4SaRFEf23es5p9sp
58KfWMv40TtbWM/4Ge4d4SaRHIvsU+o5JxzZH23kWvDDsJbxo3e2sJ7xM9w7wg0AADCFcJNIWkf25fWcczyyT23kWvDDsJbxo3e2sJ7xM9w7wk0iSY/sS+o55+Rt6fWcA++lR/asZcOlR/b0zob0yJ71bLj0yN5g7wg3iSQrsi/Tmd8ctjWy79v45eAHYC3jR+9sYT3jZ7h3hJtE0lXSlZHrK+q4/W3VpPQkSdc3VVGIC2sZP3pnC+sZP8O9I9wkmlmR/QOStn3neJmkuyLXcyW1acqiEBfWMn70zhbWM35GexdyzrmYTgyFGruWmGVlZWnJkiXRjzMyMtSpUyeVlJSotLQ0enzUqFHav3+/FyXWFlOHm1CepN9LOlfSENX8N8DXJR2WdI2k1yQle1Vcbc53zZNC8s8sBGktfSdgvfPbLPhqDqTAraevBKh3sc5BIMPNwIEDtWnTpu89Lz09XXv37m38gr6Pv56Tajwv6VFJH6nmV9pnSPqlpBmSWnhX1qn89oQu+fBJPSBr6UsB6p3fZsF3cyAFaj19JyC9Mx1uAsdfz0mB4rcndMmnT+owz2+zwBzAC7HOAe+5AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhyjtcFAIET8roAJCTndQFAcPDKDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwJZDhpmfPnsrNzVVhYaF27NihyspKOeeUn5/vdWnBsUrSIEntJKVIulzSPEmVHtaEBmEO4kfvjOJ5LX7GeneO1wXE484779T06dO9LiO4pktapJrVz5YUlrRB0r2S1kr6H0nJXhWHWDEH8aN3Bk0Xz2vxmi5zvQvkKzcff/yx5s+fr3HjxqlXr15avny51yUFxxrVPIjDkt6TtE7SHyXtlnSZpLcl3e9VcWgI5iB+9M6YNeJ5LV5rZLJ3gXzl5sknn6z1cXV1tUeVBFBBZH+fpL7fOd5R0hJJP5W0WDUP5jZNWxoahjmIH70zhue1+BntXSBfuUGcSiVtiVwfV8ftAyR1k/SNpD83VVEA8APwvBY/w70j3CSSDyP79pIuOMM5/U45FwD8jOe1+BnuHeEmkRRH9t3rOafbKecCgJ/xvBY/w70j3CSSY5F9Sj3nhCP7o41cCwCcDTyvxc9w7wg3AADAFMJNImkd2ZfXc87xyD61kWsBgLOB57X4Ge4d4SaRpEf2JfWcc/K29HrOAQC/SI/seV5ruPTI3mDvCDeJJCuyL9OZ3xy2NbLve4bbAcBPeF6Ln+HeEW4SSVdJV0aur6jj9rdVk9KTJF3fVEUBwA/A81r8DPeOcJNoZkX2D0ja9p3jZZLuilzPVaB+EiWABMfzWvyM9i7knHMxnRgKNXYtMcvKytKSJUuiH2dkZKhTp04qKSlRaWlp9PioUaO0f/9+L0qsLaYON6E8Sb+XdK6kIar5b4CvSzos6RpJr8k3vyTN+a55/pmFwM2BjwSxdzE+VTeZkPwxB1EBel7znQD1LuavCS5GqvkS7YvLwIEDY6q5R48entcaWQv/bc/J6Vo5pcopWU6XyukBOX3jeWW1Nj/y/PEU1Dnw0SWIvfMb+XELyPOaL7eA9C5WgXzlJnBi
6jDq4nzYPGYBXojxqbrJ+O6VGySEWL8m8J4bAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCkh55zzuggAAICzhVduAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYYj7cLFu2TKFQSFu3bvW6lCbz9ttvKxQKKRQK6eDBg16XA59IlFk4+dg/9fLAAw94XRp8IFHmQJIOHDigKVOmKC0tTS1btlR6eromT57sdVlN4hyvC8DZVV1drWnTpiklJUXl5eVelwN4YujQoRo/fnytY1lZWR5VAzS9kpISXXPNNZKkqVOnKi0tTV988YXef/99jytrGoQbY/77v/9bJSUluv3227Vo0SKvywE80bNnT/3yl7/0ugzAM1OmTNE555yjLVu2qEOHDl6X0+TMf1uqLhMnTlQ4HNa+ffs0YsQIhcNhpaWl6dFHH5Uk7dy5U9nZ2UpJSVGPHj20YsWKWvf/6quvdM899+iyyy5TOBxWamqqcnJytH379tP+rr1792rkyJFKSUlR586dNWPGDK1bt06hUEibNm2qde57772n4cOHq02bNmrVqpUGDhyod955J+bP66uvvtK///u/a86cOWrbtm2D+4LEY3UWJKmiokInTpxoWEOQkKzNwa5du/TKK69o5syZ6tChg06cOKHKysr4GxRACRluJKmqqko5OTnq1q2b5s2bp/T0dOXm5mrZsmUaPny4+vXrpwcffFCtW7fW+PHjVVxcHL3vnj17tGbNGo0YMUIPP/ywZs6cqZ07d2rgwIH64osvoueVl5crOztb69ev19133638/Hxt3rxZ995772n1bNiwQddee62OHj2q2bNnq6CgQIcPH1Z2dnbMLyPef//96tKli6ZMmfLDG4SEYXEWli1bppSUFCUnJ+uSSy457YsRcCpLc7B+/XpJ0nnnnachQ4YoOTlZycnJysnJ0d/+9rez0zC/c8YVFhY6SW7Lli3RYxMmTHCSXEFBQfTYoUOHXHJysguFQm7lypXR47t27XKS3OzZs6PHTpw44aqqqmr9PcXFxS4pKcnNmTMneuyhhx5yktyaNWuixyoqKlyvXr2cJLdx40bnnHPV1dXu4osvdsOGDXPV1dXRc7/++mt3wQUXuKFDh37v57l9+3bXvHlzt27dOuecc7Nnz3aS3Jdffvm990ViSJRZ6N+/v1u4cKF78cUX3WOPPeYuvfRSJ8ktWbLk+5sE8xJhDu6++24nyXXo0MENHz7cPffcc27+/PkuHA67jIwMV15eHluzAixhX7mRpNtvvz16vW3btsrMzFRKSorGjBkTPZ6Zmam2bdtqz5490WNJSUlq1qymdVVVVSorK1M4HFZmZqa2bdsWPe/VV19VWlqaRo4cGT3WsmVL3XHHHbXq+Oijj7R7926NGzdOZWVlOnjwoA4ePKjy8nINGTJEb775pqqrq+v9XO6++27l5OToZz/7WXzNQEKzNAvvvPOO8vLyNHLkSE2dOlUffPCBLr30Us2aNUsVFRXxNQgJwcocHD9+XJLUpUsXvfzyyxozZozuuecePfHEE/rrX/+aEK9kJuwbilu2bKlOnTrVOtamTRt17dpVoVDotOOHDh2KflxdXa1FixZpyZIlKi4uVlVVVfS2775xa+/evcrIyDjtz7voootqfbx7925J0oQJE85Y75EjR9SuXbs6b3vuuee0efNmffzxx2e8P3AmlmahLi1atFBu
bm406AwYMCDm+yJxWJqD5ORkSdKYMWOioUuSbr75Zt16663avHlzrSBnUcKGm+bNmzfouHMuer2goED333+/brvtNs2dO1ft27dXs2bNNH369O/9V2VdTt5n/vz56tOnT53nhMPhM95/5syZuvnmm9WiRYvo91MPHz4sqea/A3777bf60Y9+1OC6kBgszcKZdOvWTVLNGz+Buliag5PP9+edd16t482bN1eHDh1qBTOrEjbc/BAvvPCCBg8erCeffLLW8cOHD6tjx47Rj3v06KFPPvlEzrlaSf2zzz6rdb+MjAxJUmpqqq677roG11NSUqIVK1bU+VJj3759dfnll+ujjz5q8J8LfB+/zcKZnPwWwqn/MgfOBr/NwRVXXCFJKi0trXX822+/1cGDBxNiDhL6PTfxat68ea3ULkmrVq067YE0bNgwlZaW6qWXXooeO3HihJ544ola511xxRXKyMjQggULot8r/a4vv/yy3npWr1592uWWW26RJC1fvlyPPPJIgz4/IFZ+m4W6bj927JgWLlyojh07Rp/0gbPJb3MwaNAgde7cWUVFRbV+HMKyZctUVVWloUOHxvy5BRWv3MRhxIgRmjNnjiZNmqT+/ftr586dKioq0oUXXljrvClTpmjx4sUaO3as8vLydP7556uoqEgtW7aUpGhyb9asmZYuXaqcnBz17t1bkyZNUlpamkpLS7Vx40alpqZq7dq1Z6znhhtuOO3YyVdqcnJyav3LATib/DYLjz76qNasWaOf//zn6t69u/7+97/rqaee0r59+/SHP/xBLVq0aLxmIGH5bQ6SkpI0f/58TZgwQddee61uvfVW7du3T4sWLdJPf/pTjR49uvGa4ROEmzjMmjVL5eXlWrFihZ577jn17dtXL7/8su67775a54XDYW3YsEHTpk3TokWLFA6HNX78ePXv31833nhj9AEt1STtd999V3PnztXixYt1/PhxdenSRVdffTU/twa+5bdZuOaaa7R582YtXbpUZWVlSklJ0VVXXaWnnnpK2dnZjdIDwG9zIEnjx49XixYt9MADD2jmzJlq27atpkyZooKCgjO+j8iSkDv1tTQ0uoULF2rGjBn6/PPPlZaW5nU5gGeYBYA5aAyEm0ZWUVER/W95Us33V7OyslRVVaVPP/3Uw8qApsUsAMxBU+HbUo1s9OjR6t69u/r06aMjR47omWee0a5du1RUVOR1aUCTYhYA5qCpEG4a2bBhw7R06VIVFRWpqqpKl1xyiVauXBn930xAomAWAOagqfBtKQAAYAo/5wYAAJhCuAEAAKYQbgAAgCkxv6E4pND3nwScZU7+e0sYs/AD0Lq4+e3tkcwBvBDr1wReuQEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAApgQ73KySNEhSO0kpki6XNE9SpYc1BQW9s4O1bLCePXsqNzdXhYWF2rFjhyorK+WcU35+vtel4YdgFuJnrXcuRvLblicnyekcOf1MTqPl1DZybICcvva8Qv9ueQpM7/zI657U2vIUmLWUi9Tlg8sjjzxS59rm5+d7XtuZLn7jyeOnvi1PwZoFP215CkzvYn98BvGBvFo1TQ/L6YPvHP9STpdFbvuV51X6c1utQPXOj7zuSXRbrUCtpZw8DwgnL5MnT3bz5s1zY8eOdZmZme7pp592zhFuGsKTx8+ZttUK3iz4ZVutQPUu9sdnEB/IV6qm4b+r47a3Irclyemw55X6bwtY7/zI655Et4CtpZw8DwhnuhQWFjrnCDcN0aSPm+/bgjgLftkC1rtYBe89N6WStkSuj6vj9gGSukn6
RtKfm6qogKB3drCWQA1mIX6Gexe8cPNhZN9e0gVnOKffKeeiBr2zg7UEajAL8TPcu+CFm+LIvns953Q75VzUoHd2sJZADWYhfoZ7F7xwcyyyT6nnnHBkf7SRawkaemcHawnUYBbiZ7h3wQs3AAAA9QheuGkd2ZfXc87xyD61kWsJGnpnB2sJ1GAW4me4d8ELN+mRfUk955y8Lb2ecxJRemRP74IvPbJnLZHo0iN7ZqHh0iN7g70LXrjJiuzLdOY3OG2N7Ps2fjmBQu/sYC2BGsxC/Az3LnjhpqukKyPXV9Rx+9uqSZpJkq5vqqICgt7ZwVoCNZiF+BnuXfDCjSTNiuwfkLTtO8fLJN0VuZ4rqU1TFhUQ9M4O1hKowSzEz2jvQs45F9OJCjV2LQ2TJ+n3ks6VNEQ1/5XtdUmHJV0j6TVJyV4V53MB6p1TTA/PJuWrWQjQWkqSX1qXlZWlJUuWRD/OyMhQp06dVFJSotLS0ujxUaNGaf/+/V6UeJoYn6qbjK/mQAreLPhJgHoX69eE4IYbSXpe0qOSPlLNr2XPkPRLSTMktfCurEAISO8INzEIyFpK8k24GThwoDZt2vS956Wnp2vv3r2NX1AMCDcxCNIs+E1AepcY4QbmEW6MoXVxI9wAsX9NCOZ7bgAAAM6AcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwJdrhZJWmQpHaSUiRdLmmepEoPawoKemcHa9lgPXv2VG5urgoLC7Vjxw5VVlbKOaf8/HyvS8MPwSzEz1rvXIzkty1PTpLTOXL6mZxGy6lt5NgAOX3teYX+3fIUmN75kdc9qbXlKTBrKRepyweXRx55pM61zc/P97y2M138xpPHT31bnoI1C37a8hSY3sX++AziA3m1apoeltMH3zn+pZwui9z2K8+r9Oe2WoHqnR953ZPotlqBWks5eR4QTl4mT57s5s2b58aOHesyMzPd008/7Zwj3DSEJ4+fM22rFbxZ8Mu2WoHqXeyPzyA+kK9UTcN/V8dtb0VuS5LTYc8r9d8WsN75kdc9iW4BW0s5eR4QznQpLCx0zhFuGqJJHzfftwVxFvyyBax3sQree25KJW2JXB9Xx+0DJHWT9I2kPzdVUQFB7+xgLYEazEL8DPcueOHmw8i+vaQLznBOv1PORQ16ZwdrCdRgFuJnuHfBCzfFkX33es7pdsq5qEHv7GAtgRrMQvwM9y544eZYZJ9SzznhyP5oI9cSNPTODtYSqMEsxM9w74IXbgAAAOoRvHDTOrIvr+ec45F9aiPXEjT0zg7WEqjBLMTPcO+CF27SI/uSes45eVt6PeckovTInt4FX3pkz1oi0aVH9sxCw6VH9gZ7F7xwkxXZl+nMb3DaGtn3bfxyAoXe2cFaAjWYhfgZ7l3wwk1XSVdGrq+o4/a3VZM0kyRd31RFBQS9s4O1BGowC/Ez3LvghRtJmhXZPyBp23eOl0m6K3I9V1KbpiwqIOidHawlUINZiJ/R3oWccy6mExVq7FoaJk/S7yWdK2mIav4r2+uSDku6RtJrkpK9Ks7nAtQ7p5genk3KV7MQoLWU
JL+0LisrS0uWLIl+nJGRoU6dOqmkpESlpaXR46NGjdL+/fu9KPE0MT5VNxlfzYEUvFnwkwD1LtavCcENN5L0vKRHJX2kml/LniHpl5JmSGrhXVmBEJDeEW5iEJC1lOSbcDNw4EBt2rTpe89LT0/X3r17G7+gGBBuYhCkWfCbgPQuMcINzCPcGEPr4ka4AWL/mhDM99wAAACcAeEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCkh55zzuggAAICzhVduAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYYj7cLFu2TKFQSFu3bvW6lEZz8nM806WoqMjrEuEDiTALknTkyBH9+te/1sUXX6zk5GT16NFDkydP1r59+7wuDT6QKHNw4MABTZo0SZ07d1ZycrL69u2rVatWeV1WkznH6wLww1177bX6wx/+cNrxRx55RNu3b9eQIUM8qApoetXV1Ro6dKg++eQT3XXXXerZs6c+++wzLVmyROvWrdP//d//qXXr1l6XCTSqo0ePasCAATpw4IDy8vLUpUsXPf/88xozZoyKioo0btw4r0tsdIQbAy688EJdeOGFtY5VVFTorrvuUnZ2trp06eJRZUDT+stf/qItW7Zo8eLF+rd/+7fo8czMTN12221av369Ro0a5WGFQOP7r//6L3322Wd6/fXXlZ2dLUm688479ZOf/ES/+tWvdNNNN6lFixYeV9m4zH9bqi4TJ05UOBzWvn37NGLECIXDYaWlpenRRx+VJO3cuVPZ2dlKSUlRjx49tGLFilr3/+qrr3TPPffosssuUzgcVmpqqnJycrR9+/bT/q69e/dq5MiRSklJUefOnTVjxgytW7dOoVBImzZtqnXue++9p+HDh6tNmzZq1aqVBg4cqHfeeSeuz3Ht2rU6duyY/uVf/iWu+yMxWJuFo0ePSpLOO++8WsfPP/98SVJycnLMvUHisDYHb731ljp16hQNNpLUrFkzjRkzRvv379cbb7wRR5eCJSHDjSRVVVUpJydH3bp107x585Senq7c3FwtW7ZMw4cPV79+/fTggw+qdevWGj9+vIqLi6P33bNnj9asWaMRI0bo4Ycf1syZM7Vz504NHDhQX3zxRfS88vJyZWdna/369br77ruVn5+vzZs369577z2tng0bNujaa6/V0aNHNXv2bBUUFOjw4cPKzs7W+++/3+DPr6ioSMnJyRo9enR8DULCsDQL/fr1U0pKiu6//35t2LBBpaWleuONN/TrX/9aV155pa677rqz1ziYYmkOvvnmmzqDfKtWrSRJH3zwQbxtCg5nXGFhoZPktmzZEj02YcIEJ8kVFBREjx06dMglJye7UCjkVq5cGT2+a9cuJ8nNnj07euzEiROuqqqq1t9TXFzskpKS3Jw5c6LHHnroISfJrVmzJnqsoqLC9erVy0lyGzdudM45V11d7S6++GI3bNgwV11dHT3366+/dhdccIEbOnRogz7nsrIy16JFCzdmzJgG3Q+2Jcos/OlPf3Lnn3++kxS9DBs2zB07duz7mwTzEmEOpk2b5po1a+b+9re/1Tr+i1/8wklyubm59d7fgoR95UaSbr/99uj1tm3bKjMzUykpKRozZkz0eGZmptq2bas9e/ZEjyUlJalZs5rWVVVVqaysTOFwWJmZmdq2bVv0vFdffVVpaWkaOXJk9FjLli11xx131Krjo48+0u7duzVu3DiVlZXp
4MGDOnjwoMrLyzVkyBC9+eabqq6ujvnzeuGFF/Ttt9/yLSnEzNIsdOrUSVlZWfqP//gPrVmzRv/v//0/vfXWW5o0aVJ8zUHCsDIHt99+u5o3b64xY8Zo8+bN+utf/6r//M//1OrVqyXVvCfTuoR9Q3HLli3VqVOnWsfatGmjrl27KhQKnXb80KFD0Y+rq6u1aNEiLVmyRMXFxaqqqore1qFDh+j1vXv3KiMj47Q/76KLLqr18e7duyVJEyZMOGO9R44cUbt27WL63IqKitS+fXvl5OTEdD4Sm6VZ2LNnjwYPHqzly5frxhtvlCT98z//s9LT0zVx4kS98sorzAXqZGkOfvzjH2vFihWaOnWqrrnmGklSly5dtHDhQt15550Kh8Nn/HOtSNhw07x58wYdd85FrxcUFOj+++/Xbbfdprlz56p9+/Zq1qyZpk+f3qBXWE46eZ/58+erT58+dZ4T64Nx3759euutt/Sv//qvOvfccxtcCxKPpVlYtmyZTpw4oREjRtQ6fvJfyu+88w7hBnWyNAeSdNNNN2nkyJHavn27qqqq1Ldv3+gblnv27NngmoImYcPND/HCCy9o8ODBevLJJ2sdP3z4sDp27Bj9uEePHvrkk0/knKuV1D/77LNa98vIyJAkpaam/uA3PD777LNyzvEtKTQJv83CgQMH5Jyr9S9nSaqsrJQk/eMf/2jwnwl8H7/NwUktWrTQlVdeGf14/fr1kpQQb6xP6PfcxKt58+a1UrskrVq1SqWlpbWODRs2TKWlpXrppZeix06cOKEnnnii1nlXXHGFMjIytGDBAh0/fvy0v+/LL7+MubYVK1aoe/fuGjBgQMz3AeLlt1no2bOnnHN6/vnnax1/9tlnJUlZWVnf/0kBDeS3OajL7t279fjjj2vEiBG8coO6jRgxQnPmzNGkSZPUv39/7dy5U0VFRaf9IL0pU6Zo8eLFGjt2rPLy8nT++eerqKhILVu2lKRocm/WrJmWLl2qnJwc9e7dW5MmTVJaWppKS0u1ceNGpaamau3atd9b18cff6wdO3bovvvuO+17ukBj8NssTJw4UQsWLNCUKVP04Ycfqnfv3tq2bZuWLl2q3r178wP80Cj8NgeSdMkll+jmm29W9+7dVVxcrMcee0zt27fX448/3jhN8BnCTRxmzZql8vJyrVixQs8995z69u2rl19+Wffdd1+t88LhsDZs2KBp06Zp0aJFCofDGj9+vPr3768bb7wx+oCWpEGDBundd9/V3LlztXjxYh0/flxdunTR1VdfrSlTpsRU18nfIZUIP1ob/uC3WejQoYO2bt2q3/zmN1q7dq0ef/xxdejQQbfddpsKCgrM/1RWeMNvcyBJl19+uQoLC3XgwAF17NhRY8aM0W9/+1t17tz5rH/+fhRyp76Whka3cOFCzZgxQ59//rnS0tK8LgfwDLMAMAeNgXDTyCoqKmr9pMgTJ04oKytLVVVV+vTTTz2sDGhazALAHDQVvi3VyEaPHq3u3burT58+OnLkiJ555hnt2rUr+i0kIFEwCwBz0FQIN41s2LBhWrp0qYqKilRVVaVLLrlEK1eu1C233OJ1aUCTYhYA5qCp8G0pAABgCj/nBgAAmEK4AQAAphBuAACAKTG/oZifeAsv+PEtYSExC2h6Tv6aBeYAXoh1DnjlBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYEshw07NnT+Xm5qqwsFA7duxQZWWlnHPKz8/3ujTfo3cGrZI0SFI7SSmSLpc0T1KlhzUFBb2zhfWMn7HeneN1AfG48847NX36dK/LCCR6Z8x0
SYtUM8nZksKSNki6V9JaSf8jKdmr4nxuuuidJdPFesZrusz1LpCv3Hz88ceaP3++xo0bp169emn58uVelxQY9M6QNap5QgpLek/SOkl/lLRb0mWS3pZ0v1fF+dwa0TtL1oj1jNca2eydi5Ek314KCwudc87l5+d7XkvQLn7vnR/JL9uVkT79ro7b3orcliSnw55X6r8tgL3zG6/7UWsL4Hr6ZgtY72IVyFdugIRXKmlL5Pq4Om4fIKmbpG8k/bmpigoIemcL6xk/w70j3ABB9GFk317SBWc4p98p56IGvbOF9Yyf4d4RboAgKo7su9dzTrdTzkUNemcL6xk/w70j3ABBdCyyT6nnnHBkf7SRawkaemcL6xk/w70j3AAAAFMIN0AQtY7sy+s553hkn9rItQQNvbOF9Yyf4d4RboAgSo/sS+o55+Rt6fWck4jSI3t6Z0N6ZM96Nlx6ZG+wd4QbIIiyIvsynfmNflsj+76NX06g0DtbWM/4Ge4d4QYIoq6SroxcX1HH7W+r5l9cSZKub6qiAoLe2cJ6xs9w7wg3QFDNiuwfkLTtO8fLJN0VuZ4rqU1TFhUQ9M4W1jN+RnsXcs65mE4MhRq7lphlZWVpyZIl0Y8zMjLUqVMnlZSUqLS0NHp81KhR2r9/vxcl+lbQehfjw7NJheSfWVCepN9LOlfSENX8l87XJR2WdI2k1xS4X3jXZALWOyd/zYKv5kAK3Hr6SoB6F/McxPx7RHzwe4ZOXgYOHBhTzT169PC8Vr9dgtY7P5LftufkdK2cUuWULKdL5fSAnL7xvDL/bwHqnd943Y86twCtp++2gPQuVoF85QaJI8aHZ5Py3b9YkRCc/DULzAG8EOsc8J4bAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCnneF0AEDghrwtAQnJeFwAEB6/cAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTAhluevbsqdzcXBUWFmrHjh2qrKyUc075+flel+Z79M4O1jJ+9M6oVZIGSWonKUXS5ZLmSar0sKagMNa7c7wuIB533nmnpk+f7nUZgUTv7GAt40fvDJouaZFqvqplSwpL2iDpXklrJf2PpGSvivO56TLXu0C+cvPxxx9r/vz5GjdunHr16qXly5d7XVJg0Ds7WMv40Ttj1qjmi3NY0nuS1kn6o6Tdki6T9Lak+70qzufWyGTvAvnKzZNPPlnr4+rqao8qCR56ZwdrGT96Z0xBZH+fpL7fOd5R0hJJP5W0WDVfpNs0bWm+Z7R3gXzlBgAASVKppC2R6+PquH2ApG6SvpH056YqKiAM945wAwAIrg8j+/aSLjjDOf1OORc1DPeOcAMACK7iyL57Ped0O+Vc1DDcO8INACC4jkX2KfWcE47sjzZyLUFjuHeEGwAAYArhBgAQXK0j+/J6zjke2ac2ci1BY7h3hBsAQHClR/Yl9Zxz8rb0es5JROmRvcHeEW4AAMGVFdmX6cxvet0a2fc9w+2JynDvCDcAgODqKunKyPUVddz+tmpefUiSdH1TFRUQhntHuAEABNusyP4BSdu+c7xM0l2R67kK1E/YbTJGexdyzrmYTgyFGruWmGVlZWnJkiXR
jzMyMtSpUyeVlJSotLQ0enzUqFHav3+/FyX6VtB6F+PDs0n5ZRaCtpZ+EsTe+W0WQvLHHETlSfq9pHMlDVHNf29+XdJhSddIek2B++WPTSZAvXOKbQ4C+bulUlNT9ZOf/OS04926dVO3bt2iHyclJTVlWYFA7+xgLeNH7wxapJovxI9K2iypUlKGan5n0gxJLbwrzfcM9i6Qr9wgcfjtX6sSswBv+G0WfPfKDRJCrK/c8J4bAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCkh55zzuggAAICzhVduAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCn/H1Qs/yZ88sFLAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 600x600 with 9 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Generate images\n",
    "generated_images = create_bars_and_stripes_dataset(9)\n",
    "plot_nine_images(generated_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d106636b-b885-4404-862f-f34fdb59756a",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "We wrap the dataset in a PyTorch `DataLoader`, which feeds it to the GAN in shuffled batches of 64 samples during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f5d147ea28bc06f2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:01:58.844465Z",
     "start_time": "2024-02-21T14:01:58.842762Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:38.195512Z",
     "iopub.status.busy": "2024-05-07T14:48:38.194338Z",
     "iopub.status.idle": "2024-05-07T14:48:42.056682Z",
     "shell.execute_reply": "2024-05-07T14:48:42.056030Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Create DataLoader for training\n",
    "tensor_dataset = TensorDataset(torch.tensor(dataset, dtype=torch.float))\n",
    "dataloader = DataLoader(tensor_dataset, batch_size=64, shuffle=True)"
   ]
  },
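  {
   "cell_type": "markdown",
   "id": "inspect-dataloader-batches",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "To see what the training loop receives from the `DataLoader`, the sketch below builds one in the same way on a stand-in tensor of 1000 random 2x2 binary images (used only for its shape) and inspects the batches. Because `TensorDataset` wraps a single tensor, each batch is a one-element sequence, which is why the training loop reads `batch[0]`. With 1000 samples and `batch_size=64` there are 16 batches, and the last one holds 1000 - 15 * 64 = 40 samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inspect-dataloader-batches-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Stand-in data with the same shape as the Bars and Stripes dataset: (1000, 2, 2)\n",
    "stand_in = torch.randint(0, 2, (1000, 2, 2)).float()\n",
    "stand_in_loader = DataLoader(TensorDataset(stand_in), batch_size=64, shuffle=True)\n",
    "\n",
    "batch_sizes = []\n",
    "for batch in stand_in_loader:\n",
    "    images = batch[0]  # the single tensor wrapped by TensorDataset\n",
    "    batch_sizes.append(images.size(0))\n",
    "\n",
    "print(len(stand_in_loader), batch_sizes[-1], tuple(images.shape[1:]))"
   ]
  },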
  {
   "cell_type": "markdown",
   "id": "c702902d-0242-4cbf-9bfa-e12e49c7ae38",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## 2 Classical network "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "942b1203-cf34-4a4b-8a45-51dcbd11c847",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 2.1 Defining a classical GAN\n",
    "\n",
    "We begin by defining the architectures of the generator and discriminator for the classical GAN.\n",
    "We use `tensorboard` to log training metrics (uncomment the line in the next cell to install the package)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1157ac4e-2e99-42a1-8d0f-6fc4cfee7e38",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:42.059950Z",
     "iopub.status.busy": "2024-05-07T14:48:42.059306Z",
     "iopub.status.idle": "2024-05-07T14:48:42.062531Z",
     "shell.execute_reply": "2024-05-07T14:48:42.061939Z"
    }
   },
   "outputs": [],
   "source": [
    "# ! pip install tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "79d4cced-1dee-4dcc-9c46-104f33d62020",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:42.064918Z",
     "iopub.status.busy": "2024-05-07T14:48:42.064454Z",
     "iopub.status.idle": "2024-05-07T14:48:42.067818Z",
     "shell.execute_reply": "2024-05-07T14:48:42.067229Z"
    }
   },
   "outputs": [],
   "source": [
    "import pathlib\n",
    "\n",
    "path = (\n",
    "    pathlib.Path(__file__).parent.resolve()\n",
    "    if \"__file__\" in locals()\n",
    "    else pathlib.Path(\".\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e4b989f85ea36274",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:01:58.849345Z",
     "start_time": "2024-02-21T14:01:58.847144Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:42.070432Z",
     "iopub.status.busy": "2024-05-07T14:48:42.069928Z",
     "iopub.status.idle": "2024-05-07T14:48:42.075935Z",
     "shell.execute_reply": "2024-05-07T14:48:42.075306Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class Generator(nn.Module):\n",
    "    def __init__(self, input_size=2, output_size=4, hidden_size=32):\n",
    "        super(Generator, self).__init__()\n",
    "        self.model = nn.Sequential(\n",
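    "            nn.Linear(input_size, hidden_size // 2),  # hidden layer 1\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_size // 2, hidden_size),  # hidden layer 2\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_size, output_size),\n",
    "            nn.Sigmoid(),  # Sigmoid activation to output pixel probabilities\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        probs = self.model(x)\n",
    "        # Round to binary pixels in the forward pass. torch.round alone has zero\n",
    "        # gradient, so the straight-through trick passes gradients to probs unchanged\n",
    "        return probs + (torch.round(probs) - probs).detach()\n",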
    "\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    def __init__(self, input_size=4, hidden_size=16):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(input_size, hidden_size // 2),  # hidden layer 1\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(hidden_size // 2, hidden_size),  # hidden layer 2\n",
    "            nn.LeakyReLU(0.25),\n",
    "            nn.Dropout(0.3),\n",
    "            nn.Linear(hidden_size, 1),\n",
    "            nn.Sigmoid(),  # Sigmoid activation to output probabilities\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 4)  # Flatten input for fully connected layers\n",
    "        return self.model(x)"
   ]
  },
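  {
   "cell_type": "markdown",
   "id": "round-gradient-check",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "PyTorch defines the gradient of `torch.round` to be zero, so a generator whose output passes directly through it receives zero gradients from `g_loss.backward()` and its weights never change. A common workaround is the straight-through estimator: round in the forward pass, but let gradients flow as if rounding were the identity. The toy check below compares the two by summing the absolute weight gradients; a single `nn.Linear` layer stands in for the generator, and `plain_round` and `straight_through_round` are hypothetical helpers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "round-gradient-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "toy_layer = nn.Linear(2, 4)  # toy stand-in for a generator\n",
    "toy_input = torch.randn(8, 2)\n",
    "\n",
    "\n",
    "def plain_round(probs):\n",
    "    # Forward and backward both go through round: the gradient is zero\n",
    "    return torch.round(probs)\n",
    "\n",
    "\n",
    "def straight_through_round(probs):\n",
    "    # Forward value equals round(probs); backward treats rounding as the identity\n",
    "    return probs + (torch.round(probs) - probs).detach()\n",
    "\n",
    "\n",
    "grad_sums = {}\n",
    "for rounding in (plain_round, straight_through_round):\n",
    "    toy_layer.zero_grad()\n",
    "    out = rounding(torch.sigmoid(toy_layer(toy_input)))\n",
    "    out.sum().backward()\n",
    "    grad_sums[rounding.__name__] = toy_layer.weight.grad.abs().sum().item()\n",
    "print(grad_sums)"
   ]
  },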
  {
   "cell_type": "markdown",
   "id": "53cf99bc973320d0",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 2.2 Training a classical GAN\n",
    "\n",
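    "**Training loop** $\\rightarrow$ Define the training loop for the classical GAN. Each iteration first updates the discriminator on a batch of real samples (target 1) and a batch of generated samples (target 0), then updates the generator so that the discriminator classifies its samples as real (target 1). All three losses are logged to TensorBoard."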
   ]
  },
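  {
   "cell_type": "markdown",
   "id": "bce-loss-check",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Before writing the loop, it helps to check what `nn.BCELoss` computes. For a prediction $p$ and target $y$ it returns $-[y \\log p + (1-y) \\log(1-p)]$, averaged over the batch. With real samples labeled 1 and generated samples labeled 0, the discriminator losses become $-\\log D(x)$ and $-\\log(1 - D(G(z)))$; training the generator against target 1 gives the non-saturating generator loss $-\\log D(G(z))$. The toy check below uses made-up discriminator outputs, for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce-loss-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "criterion = nn.BCELoss()\n",
    "# Made-up discriminator outputs, for illustration only\n",
    "d_real = torch.tensor([0.9, 0.8])  # D(x) on real samples\n",
    "d_fake = torch.tensor([0.3, 0.1])  # D(G(z)) on generated samples\n",
    "\n",
    "# Losses built the same way as in the training loop\n",
    "d_real_loss = criterion(d_real, torch.ones_like(d_real))\n",
    "d_fake_loss = criterion(d_fake, torch.zeros_like(d_fake))\n",
    "g_loss = criterion(d_fake, torch.ones_like(d_fake))\n",
    "\n",
    "# The same quantities written out by hand (BCELoss averages over the batch)\n",
    "manual_real = -torch.log(d_real).mean()\n",
    "manual_fake = -torch.log(1 - d_fake).mean()\n",
    "manual_g = -torch.log(d_fake).mean()\n",
    "print(\n",
    "    torch.allclose(d_real_loss, manual_real),\n",
    "    torch.allclose(d_fake_loss, manual_fake),\n",
    "    torch.allclose(g_loss, manual_g),\n",
    ")"
   ]
  },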
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "998f0caa7922c78e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:01:58.854405Z",
     "start_time": "2024-02-21T14:01:58.852029Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:42.078506Z",
     "iopub.status.busy": "2024-05-07T14:48:42.078064Z",
     "iopub.status.idle": "2024-05-07T14:48:43.570642Z",
     "shell.execute_reply": "2024-05-07T14:48:43.569939Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from datetime import datetime\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.utils as vutils\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "def train_gan(\n",
    "    generator,\n",
    "    discriminator,\n",
    "    dataloader,\n",
    "    log_dir_name,\n",
    "    fixed_noise,\n",
    "    random_fake_data_generator,\n",
    "    num_epochs=100,\n",
    "    device=\"cpu\",\n",
    "):\n",
    "\n",
    "    # Initialize TensorBoard writer\n",
    "    run_id = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "    log_dir = os.path.join(log_dir_name, run_id)\n",
    "    writer = SummaryWriter(log_dir=log_dir)\n",
    "\n",
    "    # Define loss function and optimizer\n",
    "    criterion = nn.BCELoss()\n",
    "    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)\n",
    "    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)\n",
    "\n",
    "    generator.to(device)\n",
    "    discriminator.to(device)\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        for i, batch in enumerate(dataloader):\n",
    "            real_data = batch[0].to(device)\n",
    "            batch_size = real_data.size(0)\n",
    "\n",
    "            # Train Discriminator with real data\n",
    "            d_optimizer.zero_grad()\n",
    "            real_output = discriminator(real_data)\n",
    "            d_real_loss = criterion(real_output, torch.ones_like(real_output))\n",
    "            d_real_loss.backward()\n",
    "\n",
    "            # Train Discriminator with fake data\n",
    "            z = random_fake_data_generator(batch_size)\n",
    "            fake_data = generator(z)\n",
    "            fake_output = discriminator(fake_data.detach())\n",
    "            d_fake_loss = criterion(fake_output, torch.zeros_like(fake_output))\n",
    "            d_fake_loss.backward()\n",
    "            d_optimizer.step()\n",
    "\n",
    "            # Train Generator\n",
    "            g_optimizer.zero_grad()\n",
    "            z = random_fake_data_generator(batch_size)\n",
    "            fake_data = generator(z)\n",
    "            fake_output = discriminator(fake_data)\n",
    "            g_loss = criterion(fake_output, torch.ones_like(fake_output))\n",
    "            g_loss.backward()\n",
    "            g_optimizer.step()\n",
    "\n",
    "            # Log losses to TensorBoard\n",
    "            step = epoch * len(dataloader) + i\n",
    "            writer.add_scalar(\"Generator Loss\", g_loss.item(), step)\n",
    "            writer.add_scalar(\"Discriminator Real Loss\", d_real_loss.item(), step)\n",
    "            writer.add_scalar(\"Discriminator Fake Loss\", d_fake_loss.item(), step)\n",
    "\n",
    "            if i % 100 == 0:\n",
    "                print(\n",
    "                    f\"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], \"\n",
    "                    f\"Generator Loss: {g_loss.item():.4f}, \"\n",
    "                    f\"Discriminator Real Loss: {d_real_loss.item():.4f}, \"\n",
    "                    f\"Discriminator Fake Loss: {d_fake_loss.item():.4f}\"\n",
    "                )\n",
    "\n",
    "        # Generate and log sample images for visualization\n",
    "        # if (epoch+1) % (num_epochs // 10) == 0:\n",
    "        #     with torch.no_grad():\n",
    "        #         generated_images = generator(fixed_noise).detach().cpu()\n",
    "        #     img_grid = vutils.make_grid(generated_images, nrow=3, normalize=True)\n",
    "        #     writer.add_image('Generated Images', img_grid, epoch+1)\n",
    "\n",
    "    # Close TensorBoard writer\n",
    "    writer.close()"
   ]
  },
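  {
   "cell_type": "markdown",
   "id": "sketch-gan-bce-losses",
   "metadata": {},
   "source": [
    "All three loss terms in `train_gan` come from `nn.BCELoss`. As a standalone illustration (the discriminator outputs below are made up, and the snippet is not part of the training pipeline), the sketch checks that they reduce to the usual log-loss expressions. The discriminator minimizes -log D(x) on real samples and -log(1 - D(G(z))) on generated ones, while the generator minimizes -log D(G(z)), the non-saturating generator loss.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "criterion = nn.BCELoss()  # mean reduction by default\n",
    "\n",
    "# Hypothetical discriminator outputs: probabilities that each sample is real\n",
    "d_real = torch.tensor([0.9, 0.8])\n",
    "d_fake = torch.tensor([0.2, 0.4])\n",
    "\n",
    "# Discriminator targets: 1 for real samples, 0 for generated samples\n",
    "d_real_loss = criterion(d_real, torch.ones_like(d_real))\n",
    "d_fake_loss = criterion(d_fake, torch.zeros_like(d_fake))\n",
    "# Generator target: 1 for generated samples (non-saturating loss)\n",
    "g_loss = criterion(d_fake, torch.ones_like(d_fake))\n",
    "\n",
    "# The same quantities written out as explicit log-losses\n",
    "assert torch.isclose(d_real_loss, -torch.log(d_real).mean())\n",
    "assert torch.isclose(d_fake_loss, -torch.log(1 - d_fake).mean())\n",
    "assert torch.isclose(g_loss, -torch.log(d_fake).mean())\n",
    "print(d_real_loss.item(), d_fake_loss.item(), g_loss.item())\n",
    "```\n",
    "\n",
    "Because the generator is trained against the target 1 rather than by maximizing the discriminator's fake loss, its gradient stays large when the discriminator confidently rejects generated samples."
   ]
  },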
  {
   "cell_type": "markdown",
   "id": "d2366f4c-81e0-441a-84ce-0b66d2bf692d",
   "metadata": {},
   "source": [
    "We train the model and save the trained generator to `resources/generator_model.pth`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d59ca234-e9b1-4229-84b5-a4e1a9723c7d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:43.576607Z",
     "iopub.status.busy": "2024-05-07T14:48:43.576104Z",
     "iopub.status.idle": "2024-05-07T14:48:43.583199Z",
     "shell.execute_reply": "2024-05-07T14:48:43.582599Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fixed noise for visualizing generated samples\n",
    "fixed_noise = torch.randn(9, 2)\n",
    "\n",
    "\n",
    "def random_fake_data_for_gan(batch_size, input_size):\n",
    "    return torch.randn(batch_size, input_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "38df9d679364cd4a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:01:59.041937Z",
     "start_time": "2024-02-21T14:01:58.854469Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:43.586207Z",
     "iopub.status.busy": "2024-05-07T14:48:43.585866Z",
     "iopub.status.idle": "2024-05-07T14:48:44.272298Z",
     "shell.execute_reply": "2024-05-07T14:48:44.271642Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/10], Step [1/16], Generator Loss: 0.8444, Discriminator Real Loss: 0.8427, Discriminator Fake Loss: 0.5651\n",
      "Epoch [2/10], Step [1/16], Generator Loss: 0.8422, Discriminator Real Loss: 0.8406, Discriminator Fake Loss: 0.5667\n",
      "Epoch [3/10], Step [1/16], Generator Loss: 0.8290, Discriminator Real Loss: 0.8291, Discriminator Fake Loss: 0.5767\n",
      "Epoch [4/10], Step [1/16], Generator Loss: 0.8177, Discriminator Real Loss: 0.8202, Discriminator Fake Loss: 0.5793\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [5/10], Step [1/16], Generator Loss: 0.8161, Discriminator Real Loss: 0.8139, Discriminator Fake Loss: 0.5834\n",
      "Epoch [6/10], Step [1/16], Generator Loss: 0.8090, Discriminator Real Loss: 0.8019, Discriminator Fake Loss: 0.5861\n",
      "Epoch [7/10], Step [1/16], Generator Loss: 0.8042, Discriminator Real Loss: 0.7873, Discriminator Fake Loss: 0.5938\n",
      "Epoch [8/10], Step [1/16], Generator Loss: 0.8005, Discriminator Real Loss: 0.7968, Discriminator Fake Loss: 0.5980\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [9/10], Step [1/16], Generator Loss: 0.7985, Discriminator Real Loss: 0.7806, Discriminator Fake Loss: 0.6012\n",
      "Epoch [10/10], Step [1/16], Generator Loss: 0.7927, Discriminator Real Loss: 0.7774, Discriminator Fake Loss: 0.6042\n"
     ]
    }
   ],
   "source": [
    "generator = Generator(input_size=2, output_size=4, hidden_size=32)\n",
    "discriminator = Discriminator(input_size=4, hidden_size=16)\n",
    "\n",
     "# For simplicity, we start from a pretrained generator and fine-tune it for a few more epochs\n",
    "checkpoint = torch.load(path / \"resources/generator_trained_model.pth\")\n",
    "generator.load_state_dict(checkpoint)\n",
    "\n",
    "train_gan(\n",
    "    generator=generator,\n",
    "    discriminator=discriminator,\n",
    "    dataloader=dataloader,\n",
    "    log_dir_name=\"logs\",\n",
    "    fixed_noise=fixed_noise,\n",
    "    random_fake_data_generator=lambda b_size: random_fake_data_for_gan(b_size, 2),\n",
    "    num_epochs=10,\n",
    "    device=\"cpu\",\n",
    ")\n",
    "\n",
    "# Save trained generator model\n",
    "torch.save(generator.state_dict(), path / \"resources/generator_model.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d30d237c39c3d91",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 2.3 Performance evaluation "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "70c784547d6c2882",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:01:59.046590Z",
     "start_time": "2024-02-21T14:01:59.043009Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:44.276661Z",
     "iopub.status.busy": "2024-05-07T14:48:44.275486Z",
     "iopub.status.idle": "2024-05-07T14:48:44.291939Z",
     "shell.execute_reply": "2024-05-07T14:48:44.291218Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classically trained generator accuracy: 83.00%\n"
      ]
     }
    ],
    "source": [
     "# Reload the trained generator (same architecture as used for training)\n",
     "generator = Generator(input_size=2, output_size=4, hidden_size=32)\n",
     "checkpoint = torch.load(path / \"resources/generator_model.pth\")\n",
     "generator.load_state_dict(checkpoint)\n",
     "num_samples = 100\n",
     "z = random_fake_data_for_gan(num_samples, 2)\n",
     "gen_data = generator(z)\n",
     "\n",
     "\n",
     "def evaluate_generator(samples):\n",
     "    count_err = 0\n",
     "    for img in samples:\n",
     "        img = img.reshape(2, 2)\n",
     "        diag1 = int(img[0, 0]) * int(img[1, 1])\n",
     "        diag2 = int(img[0, 1]) * int(img[1, 0])\n",
     "        # A sample is an error when exactly one of its two diagonals is fully lit\n",
     "        if (diag1 == 1 or diag2 == 1) and diag1 * diag2 != 1:\n",
     "            count_err += 1\n",
     "    return (samples.shape[0] - count_err) / samples.shape[0]\n",
     "\n",
     "\n",
     "accuracy = evaluate_generator(samples=gen_data)\n",
     "print(f\"Classically trained generator accuracy: {accuracy:.2%}\")"
   ]
  },
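  {
   "cell_type": "markdown",
   "id": "sketch-eval-rule-enumeration",
   "metadata": {},
   "source": [
    "To make the metric concrete, the standalone sketch below enumerates all 16 binary 2x2 images and applies the same rule as `evaluate_generator` (the helper `is_flagged` is hypothetical, written only for this illustration). An image counts as an error only when exactly one of its two diagonals is fully lit, so 6 of the 16 patterns are flagged, while the other 10 (bars, stripes, the blank and full images, and single-pixel images) pass. The sketch assumes binary pixels; `evaluate_generator` itself converts pixels with `int()`, which truncates toward zero rather than rounding, so a continuous output such as 0.9 is read as 0.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def is_flagged(pixels):\n",
    "    # Same rule as evaluate_generator: an error when exactly one\n",
    "    # diagonal of the 2x2 image is fully lit.\n",
    "    img = np.asarray(pixels, dtype=int).reshape(2, 2)\n",
    "    diag1 = img[0, 0] * img[1, 1]\n",
    "    diag2 = img[0, 1] * img[1, 0]\n",
    "    return bool((diag1 == 1 or diag2 == 1) and diag1 * diag2 != 1)\n",
    "\n",
    "\n",
    "# All 2**4 = 16 binary images, stored as flat vectors of length 4\n",
    "all_images = [np.array(bits) for bits in itertools.product([0, 1], repeat=4)]\n",
    "flagged = [img.reshape(2, 2).tolist() for img in all_images if is_flagged(img)]\n",
    "print(len(flagged), 'of', len(all_images), 'patterns are flagged')\n",
    "for img in flagged:\n",
    "    print(img)\n",
    "```"
   ]
  },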
  {
   "cell_type": "markdown",
   "id": "bf13723d-7f8b-4fa5-82f6-8f299afc5acd",
   "metadata": {},
   "source": [
    "We visualize nine samples drawn from the trained generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6a70f58e36d1dccc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:01:59.210255Z",
     "start_time": "2024-02-21T14:01:59.048795Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:44.296848Z",
     "iopub.status.busy": "2024-05-07T14:48:44.295618Z",
     "iopub.status.idle": "2024-05-07T14:48:45.112099Z",
     "shell.execute_reply": "2024-05-07T14:48:45.111353Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAJOCAYAAABLBSanAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/xnp5ZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA060lEQVR4nO3de3BUZZrH8V+DEkKacEecAIlGgRUdCaJOIXIJshCLYcULjtmVi5cF3CChVtQiy7ADs1k06MAURncFwzIGYdAyyjjKiOAVB2FQLutSxiEDMQ6URK4xaCY5+0c6XQZC7HRIzjlPfz9dqdN9+m14+n3znP7lpJMEHMdxBAAAYEQbtwsAAAA4nwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTz4WbVqlUKBALasWOH26W0qKefflp33HGH+vbtq0AgoKlTp7pdEjwmFnqhtLRUv/jFL3TdddepS5cu6t69u0aOHKlNmza5XRo8Ihb6oLKyUvfee6+uvPJKderUScFgUFdffbWWLVumqqoqt8trFRe4XQDOj8cee0wnT57Uddddp7/+9a9ulwO44pVXXtFjjz2mW265RVOmTNHf/vY3rV69WmPGjNFzzz2nadOmuV0i0OIqKyv1v//7v7r55puVkpKiNm3aaOvWrZozZ462bdumNWvWuF1iiyPcGPHOO++Ez9oEg0G3ywFcMWrUKB08eFDdu3cP75sxY4YGDRqkn//854QbxISuXbvqj3/8Y719M2bMUKdOnbR8+XI9+eST6tWrl0vVtQ7z35ZqyNSpUxUMBnXw4EGNHz9ewWBQSUlJeuqppyRJe/bsUXp6uhISEpScnHxWyv3666/10EMP6aqrrlIwGFRiYqIyMjK0a9eus/6vAwcOaMKECUpISFDPnj01Z84cbdy4UYFAQG+//Xa9sdu2bdO4cePUqVMndejQQSNGjNAHH3wQ0XNKTk5WIBCIbkIQs6z1wsCBA+sFG0mKi4vTzTffrC+++EInT55s4gwhFljrg3NJSUmRJB07dizqf8MvYjLcSFJ1dbUyMjLUp08fPf7440pJSVFWVpZWrVqlcePGaciQIXrsscfUsWNHTZ48WSUlJeHH7t+/X0VFRRo/fryefPJJzZ07V3v27NGIESP05ZdfhsdVVFQoPT1dmzZt0oMPPqicnBxt3bpVjzzyyFn1bN68WcOHD9eJEye0YMEC5ebm6tixY0pPT9dHH33UKnOC2BQLvXDo0CF16NBBHTp0iOrxsM9iH3z33Xc6cuSISktL9fLLL2vJkiVKTk7WZZdd1vwJ8zrHuIKCAkeSs3379vC+KVOmOJKc3Nzc8L6jR4868fHxTiAQcNauXRvev2/fPkeSs2DBgvC+06dPO9XV1fX+n5KSEicuLs5ZuHBheN8TTzzhSHKKiorC+yorK50BAwY4kpwtW7Y4juM4NTU1zuWXX+6MHTvWqampCY/95ptvnEsuucQZM2ZMk55zQkKCM2XKlCY9BvbFYi84juMUFxc77du3d+6+++4mPxb2xFIfvPDCC46k8MeQIUOc3bt3R/RYv4vZMzeSdN9994Wvd+7cWf3791dCQoImTZoU3t+/f3917txZ+/fvD++Li4tTmza1U1ddXa3y8nIFg0H1799fO3fuDI974403lJSUpAkTJoT3tW/fXvfff3+9Oj755BMVFxcrMzNT5eXlOnLkiI4cOaKKigqNHj1a7777rmpqas778wfqWO2Fb775RnfccYfi4+O1ePHiyCcEMclaH4waNUpvvvmm1q9frxkzZujCCy9URUVF0yfGh2L2DcXt27dXjx496u3r1KmTevfufdZ7Vzp16qSjR4+Gb9fU1GjZsmXKz89XSUmJqqurw/d169YtfP3AgQNKTU09698785RgcXGxJGnKlCnnrPf48ePq0qVLhM8OiJzVXqiurtbPfvYzffrp
p3r99df1ox/96Acfg9hlsQ8uuugiXXTRRZKk22+/Xbm5uRozZoyKi4vNv6E4ZsNN27Ztm7TfcZzw9dzcXM2fP1/33HOPFi1apK5du6pNmzbKzs6O6gxL3WPy8vI0aNCgBsfwE1BoKVZ74f7779fvfvc7FRYWKj09vcm1ILZY7YPvu/3225WTk6NXXnlF06dPb/Lj/SRmw01zvPjiixo1apRWrlxZb/+xY8fq/aRGcnKyPv30UzmOUy+pf/755/Uel5qaKklKTEzUTTfd1IKVA+eXV3th7ty5Kigo0NKlS3XXXXdF/e8AkfBqH5ypsrJSUu1ZH+ti+j030Wrbtm291C5J69evV1lZWb19Y8eOVVlZmV599dXwvtOnT+vZZ5+tN+6aa65RamqqlixZolOnTp31/3311VfnsXrg/PFiL+Tl5WnJkiWaN2+eZs+e3ZSnA0TFa31w5MiRs+qRpBUrVkiShgwZ0vgTMoAzN1EYP368Fi5cqGnTpmno0KHas2ePCgsLdemll9YbN336dC1fvlx33XWXZs+erYsvvliFhYVq3769JIWTe5s2bbRixQplZGRo4MCBmjZtmpKSklRWVqYtW7YoMTFRGzZsaLSmDRs2hH+nQlVVlXbv3q1f/vKXkqQJEyboxz/+8fmeBsBzvfDyyy/r4Ycf1uWXX66/+7u/0/PPP1/v/jFjxoTfgwCcL17rg+eff17PPPOMbrnlFl166aU6efKkNm7cqDfffFM//elPY+LbtISbKMybN08VFRVas2aN1q1bp8GDB+u1117To48+Wm9cMBjU5s2bNWvWLC1btkzBYFCTJ0/W0KFDddttt4U/oSVp5MiR+vDDD7Vo0SItX75cp06dUq9evXT99ddH9L3Rl156Sf/zP/8Tvv3xxx/r448/liT17t2bcIMW4bVeqAv4xcXFuvvuu8+6f8uWLYQbnHde64Nhw4Zp69ateuGFF3T48GFdcMEF6t+/v5588knNmjWrRebAawJOQ+eu0KKWLl2qOXPm6IsvvlBSUpLb5QCuoRcA+qAlEG5aWGVlpeLj48O3T58+rbS0NFVXV+uzzz5zsTKgddELAH3QWvi2VAu79dZb1bdvXw0aNEjHjx/X888/r3379qmwsNDt0oBWRS8A9EFrIdy0sLFjx2rFihUqLCxUdXW1rrjiCq1du1Z33nmn26UBrYpeAOiD1sK3pQAAgCn8nhsAAGAK4QYAAJhCuAEAAKZE/IbiM/+KKZqAdzVFzfHg5AVEL6D1ea0XeE2AGyJ9mzBnbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKb4MN/369VNWVpYKCgq0e/duVVVVyXEc5eTkuF2af6yXNFJSF0kJkq6W9LikKhdrQnRYy+gxdybwmhA9q3N3gdsFRGPmzJnKzs52uwz/ypa0TLWrny4pKGmzpEckbZD0B0nxbhWHJskWaxmtbDF3RvCaED2rc+fLMzd79+5VXl6eMjMzNWDAAK1evdrtkvyjSLUH9KCkbZI2SnpJUrGkqyS9L2m+W8WhSYrEWkarSMydIbwmRM/q3PnyzM3KlSvr3a6pqXGpEh/KDW0flTT4e/u7S8qXdKOk5ao9sHdq3dLQRKxl9Jg7U3hNiJ7VufPlmRtEqUzS9tD1zAbuHyapj6RvJf2+tYpCVFjL6DF3gHmEm1jycWjbVdIl5xgz5Iyx8CbWMnrMHWAe4SaWlIS2fRsZ0+eMsfAm1jJ6zB1gHuEmlpwMbRMaGRMMbU+0cC1oHtYyeswdYB7hBgAAmEK4iSUdQ9uKRsacCm0TW7gWNA9rGT3mDjCP
cBNLUkLb0kbG1N2X0sgYuC8ltGUtmy4ltGXuALMIN7EkLbQt17nfKLkjtB18jvvhDaxl9Jg7wDzCTSzpLena0PU1Ddz/vmq/Yo2TdHNrFYWosJbRY+4A8wg3sWZeaLtY0s7v7S+X9EDoepb4rax+wFpGj7kDTAs4juNENDAQaOlaIpaWlqb8/Pzw7dTUVPXo0UOlpaUqKysL7584caIOHTrkRon1RTTDrWi2pF9LulDSaNX+SOxbko5JukHSm/LMHwx0PDd5UkDe6QU/raXn+GzuvNYLvCbY4Le5izCySE6EVPsS7YmPESNGRFRzcnKy67Wq9pjkvcs6ORouR4lyFC9HV8rRYjn61vXK6l28yO05Oevik7X05MVHc+c1rh9X/fya4KEPv81dpHx55sZ3IpphNMTx4OR56swNYobXeoHXBLghwsjCe24AAIAthBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAplzgdgGA7wTcLgAxyXG7AMA/OHMDAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEzxZbjp16+fsrKyVFBQoN27d6uqqkqO4ygnJ8ft0vxjvaSRkrpISpB0taTHJVW5WBOahD6IHnNnC+sZPatzd4HbBURj5syZys7OdrsM/8qWtEy1q58uKShps6RHJG2Q9AdJ8W4Vh0jRB9Fj7mxhPaNnde58eeZm7969ysvLU2ZmpgYMGKDVq1e7XZJ/FKk22AQlbZO0UdJLkoolXSXpfUnz3SoOTUEfRI+5s4X1jJ7VufPlmZuVK1fWu11TU+NSJT6UG9o+Kmnw9/Z3l5Qv6UZJy1UbcDq1bmloGvogesydLaxn9KzOnS/P3CBKZZK2h65nNnD/MEl9JH0r6fetVRQAAOcX4SaWfBzadpV0yTnGDDljLAAAPkO4iSUloW3fRsb0OWMsAAA+Q7iJJSdD24RGxgRD2xMtXAsAAC2EcAMAAEwh3MSSjqFtRSNjToW2iS1cCwAALYRwE0tSQtvSRsbU3ZfSyBgAADyMcBNL0kLbcp37DcM7QtvB57gfAACPI9zEkt6Srg1dX9PA/e+r9sxNnKSbW6soAADOL8JNrJkX2i6WtPN7+8slPRC6niV+OzEAwLcCjuM4EQ0MBFq6loilpaUpPz8/fDs1NVU9evRQaWmpysrKwvsnTpyoQ4cOuVFifRHNcCuaLenXki6UNFq1Pxr+lqRjkm6Q9KY884czHc9Nnnd6wXd94CF+nLsID9Wtxit9IPlzPb3Cb3MXaR/48m9LJSYm6ic/+clZ+/v06aM+ffqEb8fFxbVmWf6xTLUh5ilJWyVVSUpV7d+bmiOpnXulIXL0QfSYO1tYz+hZnTtfnrnxHW99weUrnLkBanHmBoi8D3jPDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwA
AABTCDcAAMCUgOM4jttFAAAAnC+cuQEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYIr5cLNq1SoFAgHt2LHD7VJazfvvv69AIKBAIKAjR464XQ48IlZ6oe5z/8yPxYsXu10aPCBW+kCSDh8+rOnTpyspKUnt27dXSkqK7r33XrfLahUXuF0Azq+amhrNmjVLCQkJqqiocLscwBVjxozR5MmT6+1LS0tzqRqg9ZWWluqGG26QJM2YMUNJSUn68ssv9dFHH7lcWesg3Bjz3//93yotLdV9992nZcuWuV0O4Ip+/frpn/7pn9wuA3DN9OnTdcEFF2j79u3q1q2b2+W0OvPflmrI1KlTFQwGdfDgQY0fP17BYFBJSUl66qmnJEl79uxRenq6EhISlJycrDVr1tR7/Ndff62HHnpIV111lYLBoBITE5WRkaFdu3ad9X8dOHBAEyZMUEJCgnr27Kk5c+Zo48aNCgQCevvtt+uN3bZtm8aNG6dOnTqpQ4cOGjFihD744IOIn9fXX3+tf/u3f9PChQvVuXPnJs8LYo/VXpCkyspKnT59umkTgphkrQ/27dun119/XXPnzlW3bt10+vRpVVVVRT9BPhST4UaSqqurlZGRoT59+ujxxx9XSkqKsrKytGrVKo0bN05DhgzRY489po4dO2ry5MkqKSkJP3b//v0qKirS+PHj9eSTT2ru3Lnas2ePRowYoS+//DI8rqKiQunp6dq0aZMefPBB5eTkaOvWrXrkkUfOqmfz5s0aPny4Tpw4oQULFig3N1fHjh1Tenp6xKcR58+fr169emn69OnNnyDEDIu9sGrVKiUkJCg+Pl5XXHHFWS9GwJks9cGmTZskSRdddJFGjx6t+Ph4xcfHKyMjQ3/5y1/Oz4R5nWNcQUGBI8nZvn17eN+UKVMcSU5ubm5439GjR534+HgnEAg4a9euDe/ft2+fI8lZsGBBeN/p06ed6urqev9PSUmJExcX5yxcuDC874knnnAkOUVFReF9lZWVzoABAxxJzpYtWxzHcZyamhrn8ssvd8aOHevU1NSEx37zzTfOJZdc4owZM+YHn+euXbuctm3bOhs3bnQcx3EWLFjgSHK++uqrH3wsYkOs9MLQoUOdpUuXOq+88orz9NNPO1deeaUjycnPz//hSYJ5sdAHDz74oCPJ6datmzNu3Dhn3bp1Tl5enhMMBp3U1FSnoqIissnysZg9cyNJ9913X/h6586d1b9/fyUkJGjSpEnh/f3791fnzp21f//+8L64uDi1aVM7ddXV1SovL1cwGFT//v21c+fO8Lg33nhDSUlJmjBhQnhf+/btdf/999er45NPPlFxcbEyMzNVXl6uI0eO6MiRI6qoqNDo0aP17rvvqqamptHn8uCDDyojI0N///d/H91kIKZZ6oUPPvhAs2fP1oQJEzRjxgz96U9/0pVXXql58+apsrIyuglCTLDSB6dOnZIk9erVS6+99pomTZqkhx56SM8++6z+/Oc/x8SZzJh9Q3H79u3Vo0ePevs6deqk3r17KxAInLX/6NGj4ds1NTVatmyZ8vPzVVJSourq6vB933/j1oEDB5SamnrWv3fZZZfVu11cXCxJmjJlyjnrPX78uLp06dLgfevWrdPWrVu1d+/ecz4eOBdLvdCQdu3aKSsrKxx0hg0bFvFjETss9UF8fLwkadKkSeHQJUl33HGH7r77bm3durVekLMoZsNN27Ztm7TfcZzw9dzcXM2fP1/33HOPFi1apK5du6pNmzbKzs7+wa8qG1L3mLy8PA0aNKjBMcFg8JyPnzt3ru644w61a9cu/P3UY8eOSar9ccDvvvtOP/rRj5pcF2KDpV44lz59+kiqfeMn0BBLfVB3vL/ooovq7W/btq26detWL5hZFbPhpjlefPFFjRo1SitXrqy3/9ix
Y+revXv4dnJysj799FM5jlMvqX/++ef1HpeamipJSkxM1E033dTkekpLS7VmzZoGTzUOHjxYV199tT755JMm/7vAD/FaL5xL3bcQzvzKHDgfvNYH11xzjSSprKys3v7vvvtOR44ciYk+iOn33ESrbdu29VK7JK1fv/6sT6SxY8eqrKxMr776anjf6dOn9eyzz9Ybd8011yg1NVVLliwJf6/0+7766qtG63n55ZfP+rjzzjslSatXr9avfvWrJj0/IFJe64WG7j958qSWLl2q7t27hw/6wPnktT4YOXKkevbsqcLCwnq/DmHVqlWqrq7WmDFjIn5ufsWZmyiMHz9eCxcu1LRp0zR06FDt2bNHhYWFuvTSS+uNmz59upYvX6677rpLs2fP1sUXX6zCwkK1b99eksLJvU2bNlqxYoUyMjI0cOBATZs2TUlJSSorK9OWLVuUmJioDRs2nLOeW2655ax9dWdqMjIy6n3lAJxPXuuFp556SkVFRfrpT3+qvn376q9//auee+45HTx4UL/5zW/Url27lpsMxCyv9UFcXJzy8vI0ZcoUDR8+XHfffbcOHjyoZcuW6cYbb9Stt97acpPhEYSbKMybN08VFRVas2aN1q1bp8GDB+u1117To48+Wm9cMBjU5s2bNWvWLC1btkzBYFCTJ0/W0KFDddttt4U/oaXapP3hhx9q0aJFWr58uU6dOqVevXrp+uuv5/fWwLO81gs33HCDtm7dqhUrVqi8vFwJCQm67rrr9Nxzzyk9Pb1F5gDwWh9I0uTJk9WuXTstXrxYc+fOVefOnTV9+nTl5uae831ElgScM8+locUtXbpUc+bM0RdffKGkpCS3ywFcQy8A9EFLINy0sMrKyvCP5Um1319NS0tTdXW1PvvsMxcrA1oXvQDQB62Fb0u1sFtvvVV9+/bVoEGDdPz4cT3//PPat2+fCgsL3S4NaFX0AkAftBbCTQsbO3asVqxYocLCQlVXV+uKK67Q2rVrwz/NBMQKegGgD1oL35YCAACm8HtuAACAKYQbAABgCuEGAACYEvEbigMK/PAg4Dxz5L23hJ35F33RBN5bTt/wWi/QB83graX0lUj7gDM3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUf4eb9ZJGSuoiKUHS1ZIel1TlYk1+wdz5Xr9+/ZSVlaWCggLt3r1bVVVVchxHOTk5bpfmH/SBCfTCeWCsFy5wu4CoZUtaptpnkC4pKGmzpEckbZD0B0nxbhXncdli7gyYOXOmsrOz3S7Dv7JFHxhBLzRTtsz1gj/P3BSpdiGCkrZJ2ijpJUnFkq6S9L6k+W4V53FFYu6M2Lt3r/Ly8pSZmakBAwZo9erVbpfkH0WiDwyhF5qhSDZ7wYmQvHS5Vo4kR79s4L73QvfFydEx1yv13sVnc+dFkjz5UVBQ4DiO4+Tk5Lheyzk/vHLxWR/I8V4vuP655Ode8NLFZ70QKf+duSmTtD10PbOB+4dJ6iPpW0m/b62ifIK5A+gDoI7hXvBfuPk4tO0q6ZJzjBlyxljUYu4A+gCoY7gX/BduSkLbvo2M6XPGWNRi7gD6AKhjuBf8F25OhrYJjYwJhrYnWrgWv2HuAPoAqGO4F/wXbgAAABrhv3DTMbStaGTMqdA2sYVr8RvmDqAPgDqGe8F/4SYltC1tZEzdfSmNjIlFKaEtc4dYlhLa0geIdSmhrcFe8F+4SQtty3XuNzjtCG0Ht3w5vsLcAfQBUMdwL/gv3PSWdG3o+poG7n9ftUkzTtLNrVWUTzB3AH0A1DHcC/4LN5I0L7RdLGnn9/aXS3og
dD1LUqfWLMonmDuAPgDqGO2FQOjXaP/wQAVaupammS3p15IulDRatT/K9pakY5JukPSmfPeHvlqNj+bOUUSfnq0qEPBGL6SlpSk/Pz98OzU1VT169FBpaanKysrC+ydOnKhDhw65UeLZvLScPuoDyXu94JU+kHzYC95aSl/1QsR9EPHfEfHiZZ0cDZejRDmKl6Mr5WixHH3remXev/hk7rxIbv9dmtDHiBEjIqo3OTnZ9VrDH167+KQP5HivF1z/XPJzL3jx4pNeiJR/z9wgJjiK6NOzVXnpK1bf8d5y+obXeoE+aAZvLaWvRNoH/nzPDQAAwDkQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCn+DjfrJY2U1EVSgqSrJT0uqcrFmvyCufO9fv36KSsrSwUFBdq9e7eqqqrkOI5ycnLcLs0/6AMT6IXzwFgvXOB2AVHLlrRMtc8gXVJQ0mZJj0jaIOkPkuLdKs7jssXcGTBz5kxlZ2e7XYZ/ZYs+MIJeaKZsmesFf565KVLtQgQlbZO0UdJLkoolXSXpfUnz3SrO44rE3Bmxd+9e5eXlKTMzUwMGDNDq1avdLsk/ikQfGEIvNEORbPaCEyF56XKtHEmOftnAfe+F7ouTo2OuV+q9i8/mzoskefKjoKDAcRzHycnJcb2Wc3545eKzPpDjvV5w/XPJz73gpYvPeiFS/jtzUyZpe+h6ZgP3D5PUR9K3kn7fWkX5BHMH0AdAHcO94L9w83Fo21XSJecYM+SMsajF3AH0AVDHcC/4L9yUhLZ9GxnT54yxqMXcAfQBUMdwL/gv3JwMbRMaGRMMbU+0cC1+w9wB9AFQx3Av+C/cAAAANMJ/4aZjaFvRyJhToW1iC9fiN8wdQB8AdQz3gv/CTUpoW9rImLr7UhoZE4tSQlvmDrEsJbSlDxDrUkJbg73gv3CTFtqW69xvcNoR2g5u+XJ8hbkD6AOgjuFe8F+46S3p2tD1NQ3c/75qk2acpJtbqyifYO4A+gCoY7gX/BduJGleaLtY0s7v7S+X9EDoepakTq1ZlE8wdwB9ANQx2guB0K/R/uGBCrR0LU0zW9KvJV0oabRqf5TtLUnHJN0g6U357g99tRofzZ2jiD49W1Ug4I1eSEtLU35+fvh2amqqevToodLSUpWVlYX3T5w4UYcOHXKjxLN5aTl91AeS93rBK30g+bAXvLWUvuqFiPsg4r8j4sXLOjkaLkeJchQvR1fK0WI5+tb1yrx/8cnceZHc/rs0oY8RI0ZEVG9ycrLrtYY/vHbxSR/I8V4vuP655Ode8OLFJ70QKf+euUFMcBTRp2er8tJXrL7jveX0Da/1An3QDN5aSl+JtA/8+Z4bAACAcyDcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAwhXADAABMIdwAAABTCDcAAMAUwg0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMIN
AAAwhXADAABMIdwAAABTCDcAAMAUwg0AADAl4DiO43YRAAAA5wtnbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmGI+3KxatUqBQEA7duxwu5QWU/ccz/VRWFjodonwgFjoBUk6fvy4Hn74YV1++eWKj49XcnKy7r33Xh08eNDt0uABsdIHhw8f1rRp09SzZ0/Fx8dr8ODBWr9+vdtltZoL3C4AzTd8+HD95je/OWv/r371K+3atUujR492oSqg9dXU1GjMmDH69NNP9cADD6hfv376/PPPlZ+fr40bN+r//u//1LFjR7fLBFrUiRMnNGzYMB0+fFizZ89Wr1699Nvf/laTJk1SYWGhMjMz3S6xxRFuDLj00kt16aWX1ttXWVmpBx54QOnp6erVq5dLlQGt649//KO2b9+u5cuX61/+5V/C+/v376977rlHmzZt0sSJE12sEGh5//Vf/6XPP/9cb731ltLT0yVJM2fO1E9+8hP967/+q26//Xa1a9fO5SpblvlvSzVk6tSpCgaDOnjwoMaPH69gMKikpCQ99dRTkqQ9e/YoPT1dCQkJSk5O1po1a+o9/uuvv9ZDDz2kq666SsFgUImJicrIyNCuXbvO+r8OHDigCRMmKCEhQT179tScOXO0ceNGBQIBvf322/XGbtu2TePGjVOnTp3UoUMHjRgxQh988EFUz3HDhg06efKk/vEf/zGqxyM2WOuFEydOSJIuuuiievsvvvhiSVJ8fHzEc4PYYa0P3nvvPfXo0SMcbCSpTZs2mjRpkg4dOqR33nknilnyl5gMN5JUXV2tjIwM9enTR48//rhSUlKUlZWlVatWady4cRoyZIgee+wxdezYUZMnT1ZJSUn4sfv371dRUZHGjx+vJ598UnPnztWePXs0YsQIffnll+FxFRUVSk9P16ZNm/Tggw8qJydHW7du1SOPPHJWPZs3b9bw4cN14sQJLViwQLm5uTp27JjS09P10UcfNfn5FRYWKj4+Xrfeemt0E4SYYakXhgwZooSEBM2fP1+bN29WWVmZ3nnnHT388MO69tprddNNN52/iYMplvrg22+/bTDId+jQQZL0pz/9Kdpp8g/HuIKCAkeSs3379vC+KVOmOJKc3Nzc8L6jR4868fHxTiAQcNauXRvev2/fPkeSs2DBgvC+06dPO9XV1fX+n5KSEicuLs5ZuHBheN8TTzzhSHKKiorC+yorK50BAwY4kpwtW7Y4juM4NTU1zuWXX+6MHTvWqampCY/95ptvnEsuucQZM2ZMk55zeXm5065dO2fSpElNehxsi5Ve+N3vfudcfPHFjqTwx9ixY52TJ0/+8CTBvFjog1mzZjlt2rRx/vKXv9Tb/7Of/cyR5GRlZTX6eAti9syNJN13333h6507d1b//v2VkJCgSZMmhff3799fnTt31v79+8P74uLi1KZN7dRVV1ervLxcwWBQ/fv3186dO8Pj3njjDSUlJWnChAnhfe3bt9f9999fr45PPvlExcXFyszMVHl5uY4cOaIjR46ooqJCo0eP1rvvvquampqIn9eLL76o7777jm9JIWKWeqFHjx5KS0vTf/zHf6ioqEj//u//rvfee0/Tpk2LbnIQM6z0wX333ae2bdtq0qRJ2rp1q/785z/rP//zP/Xyyy9Lqn1PpnUx+4bi9u3bq0ePHvX2derUSb1791YgEDhr/9GjR8O3a2pqtGzZMuXn56ukpETV1dXh+7p16xa+fuDAAaWmpp7171122WX1bhcXF0uSpkyZcs56jx8/ri5dukT03AoLC9W1a1dlZGRENB6xzVIv7N+/X6NGjdLq1at12223SZL+4R/+QSkpKZo6dapef/11+gINstQHP/7xj7VmzRrNmDFDN9xwgySpV69eWrp0qWbOnKlgMHjOf9eKmA03bdu2bdJ+
x3HC13NzczV//nzdc889WrRokbp27ao2bdooOzu7SWdY6tQ9Ji8vT4MGDWpwTKSfjAcPHtR7772nf/7nf9aFF17Y5FoQeyz1wqpVq3T69GmNHz++3v66r5Q/+OADwg0aZKkPJOn222/XhAkTtGvXLlVXV2vw4MHhNyz369evyTX5TcyGm+Z48cUXNWrUKK1cubLe/mPHjql79+7h28nJyfr000/lOE69pP7555/Xe1xqaqokKTExsdlveHzhhRfkOA7fkkKr8FovHD58WI7j1PvKWZKqqqokSX/729+a/G8CP8RrfVCnXbt2uvbaa8O3N23aJEkx8cb6mH7PTbTatm1bL7VL0vr161VWVlZv39ixY1VWVqZXX301vO/06dN69tln64275pprlJqaqiVLlujUqVNn/X9fffVVxLWtWbNGffv21bBhwyJ+DBAtr/VCv3795DiOfvvb39bb/8ILL0iS0tLSfvhJAU3ktT5oSHFxsZ555hmNHz+eMzdo2Pjx47Vw4UJNmzZNQ4cO1Z49e1RYWHjWL9KbPn26li9frrvuukuzZ8/WxRdfrMLCQrVv316Swsm9TZs2WrFihTIyMjRw4EBNmzZNSUlJKisr05YtW5SYmKgNGzb8YF179+7V7t279eijj571PV2gJXitF6ZOnaolS5Zo+vTp+vjjjzVw4EDt3LlTK1as0MCBA/kFfmgRXusDSbriiit0xx13qG/fviopKdHTTz+trl276plnnmmZSfAYwk0U5s2bp4qKCq1Zs0br1q3T4MGD9dprr+nRRx+tNy4YDGrz5s2aNWuWli1bpmAwqMmTJ2vo0KG67bbbwp/QkjRy5Eh9+OGHWrRokZYvX65Tp06pV69euv766zV9+vSI6qr7G1Kx8Ku14Q1e64Vu3bppx44d+vnPf64NGzbomWeeUbdu3XTPPfcoNzfX/G9lhTu81geSdPXVV6ugoECHDx9W9+7dNWnSJP3iF79Qz549z/vz96KAc+a5NLS4pUuXas6cOfriiy+UlJTkdjmAa+gFgD5oCYSbFlZZWVnvN0WePn1aaWlpqq6u1meffeZiZUDrohcA+qC18G2pFnbrrbeqb9++GjRokI4fP67nn39e+/btC38LCYgV9AJAH7QWwk0LGzt2rFasWKHCwkJVV1friiuu0Nq1a3XnnXe6XRrQqugFgD5oLXxbCgAAmMLvuQEAAKYQbgAAgCmEGwAAYErEbyjmN942A+9qiprjwckLiF6IGlMXNa+9PZLXhGbw1lL6SqSvCZy5AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACm+DLc9OvXT1lZWSooKNDu3btVVVUlx3GUk5Pjdmn+sV7SSEldJCVIulrS45KqXKwJ0WEtm4xjiC2s53lg7DhygdsFRGPmzJnKzs52uwz/ypa0TLWrny4pKGmzpEckbZD0B0nxbhWHJskWaxkFjiG2sJ7NlC1zxxFfnrnZu3ev8vLylJmZqQEDBmj16tVul+QfRar9JA5K2iZpo6SXJBVLukrS+5Lmu1UcmqRIrGWUOIbYwno2Q5FMHkd8eeZm5cqV9W7X1NS4VIkP5Ya2j0oa/L393SXlS7pR0nLVfjJ3at3S0ESsZdQ4htjCejaD0eOIL8/cIEplkraHrmc2cP8wSX0kfSvp961VFKLCWgJoLsPHEcJNLPk4tO0q6ZJzjBlyxlh4E2sJoLkMH0cIN7GkJLTt28iYPmeMhTexlgCay/BxhHATS06GtgmNjAmGtidauBY0D2sJoLkMH0cINwAAwBTCTSzpGNpWNDLmVGib2MK1oHlYSwDN
Zfg4QriJJSmhbWkjY+ruS2lkDNyXEtqylgCilRLaGjyOEG5iSVpoW65zvzlsR2g7+Bz3wxtYSwDNZfg4QriJJb0lXRu6vqaB+99XbUqPk3RzaxWFqLCWAJrL8HGEcBNr5oW2iyXt/N7+ckkPhK5nyVe/iTJmsZYAmsvocSTgOI4T0cBAoKVriVhaWpry8/PDt1NTU9WjRw+VlpaqrKwsvH/ixIk6dOiQGyXWF9EMt6LZkn4t6UJJo1X7Y4BvSTom6QZJb8ozfyTN8dzkSQF5pxf8tJaS5JWp890xRFKEh+pWw2tCM3hrKX11HIn4NcGJkGqXwxMfI0aMiKjm5ORk12sNrYX3LuvkaLgcJcpRvBxdKUeL5ehb1yurd/Eit+fkrItP1lKO3O9Fvx5D5L1ecHs+fL2eXrz45DgSKV+eufGdiGYYDXE8OHmeOnPjN0xd1CI8VLcaXhOawVtL6SuRvibwnhsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJjiy3DTr18/ZWVlqaCgQLt371ZVVZUcx1FOTo7bpfnHekkjJXWRlCDpakmPS6pysSZEh7VsMo4htrCe54Gx48gFbhcQjZkzZyo7O9vtMvwrW9Iy1a5+uqSgpM2SHpG0QdIfJMW7VRyaJFusZRQ4htjCejZTtswdR3x55mbv3r3Ky8tTZmamBgwYoNWrV7tdkn8UqfaTOChpm6SNkl6SVCzpKknvS5rvVnFokiKxllHiGGIL69kMRTJ5HPHlmZuVK1fWu11TU+NSJT6UG9o+Kmnw9/Z3l5Qv6UZJy1X7ydypdUtDE7GWUeMYYgvr2QxGjyO+PHODKJVJ2h66ntnA/cMk9ZH0raTft1ZRiAprCaC5DB9HCDex5OPQtqukS84xZsgZY+FNrCWA5jJ8HCHcxJKS0LZvI2P6nDEW3sRaAmguw8cRwk0sORnaJjQyJhjanmjhWtA8rCWA5jJ8HCHcAAAAUwg3saRjaFvRyJhToW1iC9eC5mEtATSX4eMI4SaWpIS2pY2MqbsvpZExcF9KaMtaAohWSmhr8DhCuIklaaFtuc795rAdoe3gc9wPb2AtATSX4eMI4SaW9JZ0bej6mgbuf1+1KT1O0s2tVRSiwloCaC7DxxHCTayZF9oulrTze/vLJT0Qup4lX/0mypjFWgJoLqPHkYDjOE5EAwOBlq4lYmlpacrPzw/fTk1NVY8ePVRaWqqysrLw/okTJ+rQoUNulFhfRDPcimZL+rWkCyWNVu2PAb4l6ZikGyS9Kc/8kTTHc5MnBeSdXvDTWkqSV6bOd8cQSREeqlsNrwnN4K2l9NVxJOLXBCdCql0OT3yMGDEiopqTk5NdrzW0Ft67rJOj4XKUKEfxcnSlHC2Wo29dr6zexYvcnpOzLj5ZSzlyvxf9egyR93rB7fnw9Xp68eKT40ikfHnmxncimmE0xPHg5HnqzI3fMHVRi/BQ3Wp4TWgGby2lr0T6msB7bgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYA
AJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmEG4AAIAphBsAAGAK4QYAAJhCuAEAAKYQbgAAgCmEGwAAYArhBgAAmEK4AQAAphBuAACAKYQbAABgCuEGAACYQrgBAACmBBzHcdwuAgAA4HzhzA0AADCFcAMAAEwh3AAAAFMINwAAwBTCDQAAMIVwAwAATCHcAAAAUwg3AADAFMINAAAw5f8B7cSewWAv2v8AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 600x600 with 9 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Initialize generator for evaluation\n",
    "generator_for_evaluation = Generator(input_size=2, output_size=4)\n",
    "generator_for_evaluation.load_state_dict(\n",
    "    torch.load(path / \"resources/generator_model.pth\")\n",
    ")  # Load trained model\n",
    "generator_for_evaluation.eval()\n",
    "\n",
    "# Generate images\n",
    "with torch.no_grad():\n",
    "    noise = torch.randn(9, 2)\n",
    "    generated_images = generator_for_evaluation(noise).detach().cpu().numpy()\n",
    "\n",
     "# Plot the generated samples in a 3 by 3 grid\n",
     "plot_nine_images(generated_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66344754c5a37cb7",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "\n",
    "## 3 Quantum hybrid network implementation\n",
    "\n",
    "In this section we will define a quantum generator circuit and integrate it into a hybrid quantum-classical GAN model. We will then train the QGAN model and evaluate its performance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b51438a9-ea7f-490d-8fac-cb51d22e377b",
   "metadata": {},
   "source": [
    "### 3.1 Defining the quantum GAN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65a77441-20a0-4220-b610-0c0cf6c0dc53",
   "metadata": {},
   "source": [
    "#### 3.1.1 Defining the quantum generator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5803e3a-0fdc-4606-b532-24f3c2132e73",
   "metadata": {},
   "source": [
    "We define the three components of the quantum layer - this is where the quantum network architect's creativity comes into play!\n",
    "\n",
    "The design we choose:\n",
    "1. Data encoding - we use a `datum_angle_encoding` function that encodes $n$ data points on $n$ qubits, applying RX and RZ rotations by the angle $\\pi x_i$ to the $i$-th qubit.\n",
    "2. A variational ansatz - one RY rotation per qubit, followed by RZZ entangling gates between neighboring qubits.\n",
    "3. Classical post-processing - we return the vector $(p_1, p_2, \\dots, p_n)$, where $p_i$ is the probability of measuring 1 on the $i$-th qubit."
   ]
  },
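  {
   "cell_type": "markdown",
   "id": "sketch-ansatz-weight-count",
   "metadata": {},
   "source": [
    "The ansatz in item 2 uses one RY angle per qubit, one RZZ angle per neighboring pair, and, when there are more than two qubits, one extra RZZ angle on the first two qubits (`weights[-1]`). The helper below is a hypothetical bookkeeping sketch, not part of the Classiq model. It counts these angles and shows that for $n > 2$ qubits the total is $n + (n - 1) + 1 = 2n$, consistent with setting `num_weights = 2 * num_qubits` when the model is built.\n",
    "\n",
    "```python\n",
    "def ansatz_weight_count(num_qubits):\n",
    "    # One RY angle per qubit\n",
    "    ry_angles = num_qubits\n",
    "    # One RZZ angle per neighboring pair (0, 1), ..., (n - 2, n - 1)\n",
    "    rzz_chain = num_qubits - 1\n",
    "    # my_ansatz adds one more RZZ on qubits 0 and 1 when n > 2\n",
    "    extra_rzz = 1 if num_qubits > 2 else 0\n",
    "    return ry_angles + rzz_chain + extra_rzz\n",
    "\n",
    "\n",
    "for n in (3, 4, 5):\n",
    "    print(n, ansatz_weight_count(n), 2 * n)\n",
    "```\n",
    "\n",
    "For $n = 2$ the ansatz uses only 3 angles, so a weight vector of length $2n$ would leave one entry unused."
   ]
  },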
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "25cef3df-6932-490f-b42e-5e3e324ce974",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:45.117431Z",
     "iopub.status.busy": "2024-05-07T14:48:45.116259Z",
     "iopub.status.idle": "2024-05-07T14:48:47.029273Z",
     "shell.execute_reply": "2024-05-07T14:48:47.028620Z"
    }
   },
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "from classiq import IDENTITY, RX, RY, RZ, RZZ, CArray, CReal, QArray, if_, qfunc, repeat\n",
    "from classiq.applications.qnn.types import SavedResult\n",
    "from classiq.qmod.symbolic import floor, pi\n",
    "\n",
    "\n",
    "@qfunc\n",
    "def datum_angle_encoding(data_in: CArray[CReal], qbv: QArray) -> None:\n",
    "\n",
    "    def even_case(exe_params: CArray[CReal], qbv: QArray) -> None:\n",
    "        repeat(\n",
    "            count=exe_params.len,\n",
    "            iteration=lambda index: RX(pi * data_in[index], qbv[index]),\n",
    "        )\n",
    "        repeat(\n",
    "            count=exe_params.len,\n",
    "            iteration=lambda index: RZ(pi * data_in[index], qbv[index]),\n",
    "        )\n",
    "\n",
    "    def odd_case(data_in: CArray[CReal], qbv: QArray) -> None:\n",
    "\n",
    "        even_case(data_in, qbv)\n",
    "        RX(pi * data_in[data_in.len - 1], target=qbv[data_in.len])\n",
    "\n",
    "    if_(\n",
    "        condition=data_in.len - 2 * (floor(data_in.len / 2)) == 0,\n",
    "        then=lambda: even_case(data_in, qbv),\n",
    "        else_=lambda: odd_case(data_in, qbv),\n",
    "    )\n",
    "\n",
    "\n",
    "@qfunc\n",
    "def my_ansatz(weights: CArray[CReal], qbv: QArray) -> None:\n",
    "\n",
    "    repeat(\n",
    "        count=qbv.len,\n",
    "        iteration=lambda index: RY(weights[index], qbv[index]),\n",
    "    )\n",
    "    repeat(\n",
    "        count=qbv.len - 1,\n",
    "        iteration=lambda index: RZZ(weights[qbv.len + index], qbv[index : index + 2]),\n",
    "    )\n",
    "    if_(\n",
    "        condition=qbv.len > 2,\n",
    "        then=lambda: RZZ(weights[-1], qbv[0:2]),\n",
    "        else_=lambda: IDENTITY(qbv[0]),\n",
    "    )\n",
    "\n",
    "\n",
     "def my_post_process(result: SavedResult, num_qubits: int, num_shots: int) -> torch.Tensor:\n",
     "    # For each qubit k, estimate the probability of measuring 1 from the shot counts\n",
     "    res = result.value\n",
     "    yvec = []\n",
     "    for k in range(num_qubits):\n",
     "        counts_k = res.counts_of_qubits(k)  # outcome counts of qubit k alone\n",
     "        ones = counts_k[\"1\"] if \"1\" in counts_k else 0\n",
     "        yvec.append(ones / num_shots)\n",
     "\n",
     "    return torch.tensor(yvec)"
   ]
  },
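  {
   "cell_type": "markdown",
   "id": "sketch-marginal-probabilities",
   "metadata": {},
   "source": [
    "`my_post_process` reduces the measurement counts to one marginal probability per qubit. The sketch below repeats that computation on a plain dictionary of made-up bitstring counts instead of a Classiq `SavedResult`; the function name `marginal_probabilities` is hypothetical. It assumes the rightmost character of each bitstring corresponds to qubit 0, so check the bit ordering of your own execution results before reusing it.\n",
    "\n",
    "```python\n",
    "def marginal_probabilities(counts, num_qubits):\n",
    "    # counts maps bitstrings to shot counts.\n",
    "    # ASSUMPTION: the rightmost character of a bitstring is qubit 0.\n",
    "    num_shots = sum(counts.values())\n",
    "    probs = [0.0] * num_qubits\n",
    "    for bitstring, count in counts.items():\n",
    "        for k in range(num_qubits):\n",
    "            if bitstring[num_qubits - 1 - k] == '1':\n",
    "                probs[k] += count / num_shots\n",
    "    return probs\n",
    "\n",
    "\n",
    "example_counts = {'0001': 512, '0011': 256, '0000': 256}  # made-up counts\n",
    "probs = marginal_probabilities(example_counts, 4)\n",
    "print(probs)  # [0.75, 0.25, 0.0, 0.0]\n",
    "```"
   ]
  },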
  {
   "cell_type": "markdown",
   "id": "3411251e74a31939",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Finally, we define the quantum model, with its hyperparameters, as our `main` quantum function and synthesize it into a quantum program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "45009ccc2f91a0fd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:05.014020Z",
     "start_time": "2024-02-21T14:01:59.219051Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:47.033820Z",
     "iopub.status.busy": "2024-05-07T14:48:47.033432Z",
     "iopub.status.idle": "2024-05-07T14:48:49.634986Z",
     "shell.execute_reply": "2024-05-07T14:48:49.634272Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Opening: https://platform.classiq.io/circuit/dadac5e4-025a-4506-966f-fa197c68c756?version=0.41.0.dev39%2B79c8fd0855\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "from classiq import (\n",
    "    CArray,\n",
    "    CReal,\n",
    "    Output,\n",
    "    QArray,\n",
    "    allocate,\n",
    "    create_model,\n",
    "    qfunc,\n",
    "    show,\n",
    "    synthesize,\n",
    ")\n",
    "from classiq.execution import (\n",
    "    ExecutionPreferences,\n",
    "    execute_qnn,\n",
    "    set_quantum_program_execution_preferences,\n",
    ")\n",
    "\n",
    "NUM_SHOTS = 4096\n",
    "QLAYER_SIZE = 4\n",
    "num_qubits = int(np.ceil(QLAYER_SIZE))\n",
    "num_weights = 2 * num_qubits\n",
    "\n",
    "\n",
    "@qfunc\n",
    "def main(\n",
    "    input: CArray[CReal, QLAYER_SIZE],\n",
    "    weight: CArray[CReal, num_weights],\n",
    "    result: Output[QArray[num_qubits]],\n",
    ") -> None:\n",
    "\n",
    "    allocate(num_qubits, result)\n",
    "    datum_angle_encoding(data_in=input, qbv=result)\n",
    "\n",
    "    my_ansatz(weights=weight, qbv=result)\n",
    "\n",
    "\n",
    "model = create_model(main)\n",
    "quantum_program = synthesize(model)\n",
    "show(quantum_program)\n",
    "\n",
    "quantum_program = set_quantum_program_execution_preferences(\n",
    "    quantum_program, preferences=ExecutionPreferences(num_shots=NUM_SHOTS)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "403a05ded0c7cb89",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "**The resulting circuit is**:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e2a97ce-ac25-42d1-aa45-9b9eb6270635",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:05.016564Z",
     "start_time": "2024-02-21T14:02:05.011367Z"
    },
    "collapsed": false
   },
   "source": [
    "<center>\n",
    "<img src=\"https://docs.classiq.io/resources/qgan_circuit.png\" style=\"width:100%\">\n",
    "<figcaption align = \"middle\"> Hierarchical view of the quantum circuit for the QGAN generator. The circuit consists of an angle encoding layer, an ansatz layer, and a post-processing layer. </figcaption>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61309fe7-1b64-4c61-b500-d5c24c51ca61",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:05.027039Z",
     "start_time": "2024-02-21T14:02:05.015883Z"
    },
    "collapsed": false
   },
   "source": [
    "<center>\n",
    "<img src=\"https://docs.classiq.io/resources/qgan_angle_encoder.png\" style=\"width:100%\">\n",
    "<figcaption align = \"middle\"> The angle encoding layer applies two consecutive non-commuting rotations, RX followed by RZ, to encode each datum. </figcaption>\n",
    "</center>"
   ]
  },
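  {
   "cell_type": "markdown",
   "id": "sketch-angle-encoding",
   "metadata": {},
   "source": [
    "To see what the encoding does to a single qubit, the numpy sketch below applies RX(πx) followed by RZ(πx) to |0⟩, mirroring the two rotations in `datum_angle_encoding` above, and prints the probability of measuring 1. The helpers `rx`, `rz`, and `encode_datum` are illustrative names, not Classiq functions. The matrices follow the standard RX and RZ definitions.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def rx(theta):\n",
    "    c, s = np.cos(theta / 2), np.sin(theta / 2)\n",
    "    return np.array([[c, -1j * s], [-1j * s, c]])\n",
    "\n",
    "\n",
    "def rz(theta):\n",
    "    return np.diag([np.exp(-1j * theta / 2), np.exp(1j * theta / 2)])\n",
    "\n",
    "\n",
    "def encode_datum(x):\n",
    "    # Angle-encode one datum x on a single qubit: RX(pi*x), then RZ(pi*x), starting from |0>\n",
    "    theta = np.pi * x\n",
    "    return rz(theta) @ rx(theta) @ np.array([1.0, 0.0], dtype=complex)\n",
    "\n",
    "\n",
    "for x in (0.0, 0.5, 1.0):\n",
    "    p1 = abs(encode_datum(x)[1]) ** 2\n",
    "    print(x, round(p1, 6))  # equals sin^2(pi*x/2)\n",
    "```\n",
    "\n",
    "For these inputs, the probability of measuring 1 equals sin²(πx/2), giving 0, 0.5, and 1. On its own, the RZ rotation leaves this Z-basis probability unchanged. However, it adds a relative phase, so the RY rotations of the ansatz act on a different state than they would after RX alone."
   ]
  },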
  {
   "cell_type": "markdown",
   "id": "f0517eb8-4752-4ef7-ab45-0575a9051d94",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:05.034783Z",
     "start_time": "2024-02-21T14:02:05.021506Z"
    },
    "collapsed": false
   },
   "source": [
    "<center>\n",
    "<img src=\"https://docs.classiq.io/resources/qgan_anzats.png\" style=\"width:100%\">\n",
    "<figcaption align = \"middle\"> The ansatz layer consists of parametrized RY rotations followed by a pairwise entangler built from a sequence of RZZ gates. </figcaption>\n",
    "</center>"
   ]
  },
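  {
   "cell_type": "markdown",
   "id": "sketch-ansatz-weights",
   "metadata": {},
   "source": [
    "For `QLAYER_SIZE = 4`, `main` allocates `num_weights = 2 * num_qubits = 8` weights. The short check below lists which weight indices `my_ansatz` reads: `num_qubits` RY angles, `num_qubits - 1` RZZ angles on neighboring pairs, and, when there are more than two qubits, one extra RZZ angle taken from `weights[-1]`. The helper `ansatz_weight_indices` is hypothetical. It also assumes that `weights[-1]` resolves to the last element, as in Python.\n",
    "\n",
    "```python\n",
    "def ansatz_weight_indices(n_qubits):\n",
    "    # Indices of `weights` read by my_ansatz for a register of n_qubits qubits\n",
    "    ry = list(range(n_qubits))  # one RY angle per qubit\n",
    "    rzz = [n_qubits + i for i in range(n_qubits - 1)]  # one RZZ angle per neighboring pair\n",
    "    # Extra RZZ angle weights[-1]; assumes Python-style negative indexing into a\n",
    "    # weights array of length 2 * n_qubits, as allocated by main\n",
    "    extra = [2 * n_qubits - 1] if n_qubits > 2 else []\n",
    "    return ry + rzz + extra\n",
    "\n",
    "\n",
    "used = ansatz_weight_indices(4)\n",
    "print(len(used), sorted(used) == list(range(2 * 4)))  # 8 True\n",
    "```"
   ]
  },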
  {
   "cell_type": "markdown",
   "id": "0a899793-b2e7-4457-a00a-aece8c73dbaf",
   "metadata": {},
   "source": [
    "#### 3.1.2 Defining the hybrid network"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9dd3d2c7fcce2",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "We define the network building blocks: a hybrid quantum-classical generator that contains the quantum layer, and a classical discriminator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fbe3f74e-6b3c-4d4d-a2a5-854a432dc06f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:07.266266Z",
     "start_time": "2024-02-21T14:02:05.058392Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:49.637942Z",
     "iopub.status.busy": "2024-05-07T14:48:49.637382Z",
     "iopub.status.idle": "2024-05-07T14:48:49.645560Z",
     "shell.execute_reply": "2024-05-07T14:48:49.644924Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "from classiq.applications.qnn import QLayer\n",
    "\n",
    "\n",
    "def create_net(*args, **kwargs) -> nn.Module:\n",
    "    class QGenerator(nn.Module):\n",
    "        def __init__(self, *args, **kwargs):\n",
    "            super().__init__()\n",
    "            self.flatten = nn.Flatten()\n",
    "            self.linear_1 = nn.Linear(4, 16)\n",
    "            self.linear_2 = nn.Linear(16, 32)\n",
    "            self.linear_3 = nn.Linear(32, 16)\n",
    "            self.linear_4 = nn.Linear(16, 4)\n",
    "            self.linear_5 = nn.Linear(2, 4)\n",
    "            self.activation_1 = nn.ReLU()\n",
    "            self.activation_2 = nn.Sigmoid()\n",
    "\n",
    "            self.qlayer = QLayer(\n",
    "                quantum_program,\n",
    "                execute_qnn,\n",
    "                post_process=lambda res: my_post_process(\n",
    "                    res, num_qubits=num_qubits, num_shots=NUM_SHOTS\n",
    "                ),\n",
    "                *args,\n",
    "                **kwargs,\n",
    "            )\n",
    "\n",
    "        def forward(self, x):\n",
    "            x = self.flatten(x)\n",
    "            x = self.linear_1(x)\n",
    "            x = self.activation_2(x)\n",
    "            x = self.linear_2(x)\n",
    "            x = self.activation_1(x)\n",
    "            x = self.linear_3(x)\n",
    "            x = self.activation_1(x)\n",
    "            x = self.linear_4(x)\n",
    "            x = self.activation_2(x)\n",
    "            x = self.qlayer(x)\n",
    "            x = torch.round(self.activation_2(x))\n",
    "            return x\n",
    "\n",
    "    return QGenerator(*args, **kwargs)\n",
    "\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    def __init__(self, input_size=4, hidden_size=16):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(input_size, hidden_size // 2),  # Adjusted hidden layer size\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.Linear(hidden_size // 2, hidden_size),  # Adjusted hidden layer size\n",
    "            nn.LeakyReLU(0.25),\n",
    "            nn.Dropout(0.3),\n",
    "            nn.Linear(hidden_size, 1),\n",
    "            nn.Sigmoid(),  # Sigmoid activation to output probabilities\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 4)  # Flatten input for fully connected layers\n",
    "        return self.model(x)"
   ]
  },
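  {
   "cell_type": "markdown",
   "id": "sketch-sigmoid-round",
   "metadata": {},
   "source": [
    "In `QGenerator.forward`, the quantum layer returns per-qubit probabilities in $[0, 1]$. These pass through a sigmoid and then `torch.round`. The sigmoid maps $[0, 1]$ onto roughly $[0.5, 0.731]$, and `torch.round` rounds half to even. As a result, an output bit is 0 only when the corresponding probability estimate is exactly 0. The numpy check below stands in for the torch operations; `np.round` also rounds half to even.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1.0 / (1.0 + np.exp(-x))\n",
    "\n",
    "\n",
    "# Possible per-qubit P(1) estimates returned by the quantum layer\n",
    "p = np.linspace(0.0, 1.0, 11)\n",
    "s = sigmoid(p)\n",
    "print(s.min(), s.max())  # 0.5 and about 0.731\n",
    "print(np.round(s))  # only p = 0 rounds to 0; every p > 0 rounds to 1\n",
    "```"
   ]
  },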
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "991b4fc4-e482-4496-872f-0944ca373859",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:07.266266Z",
     "start_time": "2024-02-21T14:02:05.058392Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:49.648147Z",
     "iopub.status.busy": "2024-05-07T14:48:49.647712Z",
     "iopub.status.idle": "2024-05-07T14:48:49.654185Z",
     "shell.execute_reply": "2024-05-07T14:48:49.653617Z"
    }
   },
   "outputs": [],
   "source": [
    "q_gen = create_net()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8df1e990a6213e6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 3.2 Training the QGAN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "475fe251-568b-4bd2-ae88-1ccacc22c485",
   "metadata": {},
   "source": [
    "We reuse the training loop defined above for the classical GAN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "da89ad96-a416-4cad-b5cf-cdcd5a80387a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:49.656526Z",
     "iopub.status.busy": "2024-05-07T14:48:49.656350Z",
     "iopub.status.idle": "2024-05-07T14:48:49.659780Z",
     "shell.execute_reply": "2024-05-07T14:48:49.659179Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fixed noise for visualizing generated samples\n",
    "fixed_noise = torch.randn(4)\n",
    "\n",
    "\n",
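    "def random_fake_data_for_qgan(batch_size, input_size):\n",
    "    # Random binary noise: each entry is drawn from Bernoulli(p), with p sampled uniformly from [0, 1)\n",
    "    return torch.bernoulli(torch.rand(batch_size, input_size))"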
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88cb4abd27602131",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The `train_gan` call below logs the training process to the `q_logs` directory. You can monitor training in real time with TensorBoard. An online version is available and may be more convenient, but this notebook uses the local version, which the following (commented-out) cell launches. The next figure shows an example of the visualization that TensorBoard produces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4d4a0ca4c532ce2c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:07.307279Z",
     "start_time": "2024-02-21T14:02:07.278688Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:49.662294Z",
     "iopub.status.busy": "2024-05-07T14:48:49.661877Z",
     "iopub.status.idle": "2024-05-07T14:48:49.664776Z",
     "shell.execute_reply": "2024-05-07T14:48:49.664168Z"
    }
   },
   "outputs": [],
   "source": [
    "# Optional: create a log directory and launch TensorBoard (uncomment to use)\n",
    "# import os\n",
    "#\n",
    "# log_dir = 'MY_LOG_DIR'\n",
    "# os.makedirs(log_dir, exist_ok=True)\n",
    "#\n",
    "# %load_ext tensorboard\n",
    "# # %reload_ext tensorboard  # use instead if the extension is already loaded\n",
    "# %tensorboard --logdir='MY_LOG_DIR/q_logs'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b476da34-72cb-4b99-9c06-bf3368c4fd73",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:07.308905Z",
     "start_time": "2024-02-21T14:02:07.288999Z"
    },
    "collapsed": false
   },
   "source": [
    "<center>\n",
    "<img src=\"https://docs.classiq.io/resources/qgan_training.png\" style=\"width:100%\">\n",
    "<figcaption align = \"middle\"> Example of two training sessions. In the first (green), the loss estimate keeps rising, a clear sign that training is not converging to the desired result. In the second (orange), the losses of both components converge as the generator and the discriminator compete to improve their performance. </figcaption>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "378b85a5-a0ec-4c14-9fbd-6bc47b8f7b3b",
   "metadata": {},
   "source": [
    "***\n",
    "Since training can take a long time to run, we start from a pre-trained model whose parameters are stored in `q_generator_trained_model.pth`, and we use a smaller sample size of 250 (the pre-trained model was trained on 1000 samples). To train a randomly initialized QGAN instead, skip loading the checkpoint, change `num_samples` from 250 to 1000 when creating the dataset, and change `num_epochs` from 1 to 10 in the `train_gan` call.\n",
    "***"
   ]
  },
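  {
   "cell_type": "markdown",
   "id": "sketch-steps-per-epoch",
   "metadata": {},
   "source": [
    "With `batch_size=64` and the default `drop_last=False`, a PyTorch `DataLoader` yields one batch per 64 samples plus a final partial batch. This explains why the training log reports `Step [1/4]` for 250 samples. With 1000 samples, each epoch has 16 steps, or 160 over 10 epochs. The helper `steps_per_epoch` is a hypothetical illustration of this arithmetic.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def steps_per_epoch(num_samples, batch_size=64):\n",
    "    # A DataLoader with drop_last=False (the default) also yields the final partial batch\n",
    "    return math.ceil(num_samples / batch_size)\n",
    "\n",
    "\n",
    "print(steps_per_epoch(250))  # 4\n",
    "print(steps_per_epoch(1000))  # 16\n",
    "```"
   ]
  },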
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ada721bf-a347-49e6-983b-8911fa4ad0be",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:07.309344Z",
     "start_time": "2024-02-21T14:02:07.292707Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:49.667331Z",
     "iopub.status.busy": "2024-05-07T14:48:49.666912Z",
     "iopub.status.idle": "2024-05-07T14:48:49.682988Z",
     "shell.execute_reply": "2024-05-07T14:48:49.682357Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create training dataset for qgan\n",
    "qgan_training_dataset = create_bars_and_stripes_dataset(\n",
    "    num_samples=250\n",
    "    # num_samples=1000\n",
    ")\n",
    "\n",
    "# Convert to PyTorch tensor\n",
    "qgan_tensor_dataset = torch.tensor(qgan_training_dataset, dtype=torch.float32)\n",
    "\n",
    "# Create a TensorDataset object\n",
    "qgan_tensor_dataset = TensorDataset(qgan_tensor_dataset)\n",
    "\n",
    "# Create a DataLoader object\n",
    "qgan_dataloader = DataLoader(qgan_tensor_dataset, batch_size=64, shuffle=True)\n",
    "\n",
    "q_generator = q_gen\n",
    "discriminator = Discriminator(input_size=4, hidden_size=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0b473d12-1dc3-466f-991f-2f64e9f1988f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:07.309344Z",
     "start_time": "2024-02-21T14:02:07.292707Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:48:49.685167Z",
     "iopub.status.busy": "2024-05-07T14:48:49.684964Z",
     "iopub.status.idle": "2024-05-07T14:51:48.776942Z",
     "shell.execute_reply": "2024-05-07T14:51:48.776323Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/1], Step [1/4], Generator Loss: 0.9844, Discriminator Real Loss: 0.8843, Discriminator Fake Loss: 0.4672\n"
     ]
    }
   ],
   "source": [
    "checkpoint = torch.load(path / \"resources/q_generator_trained_model.pth\")\n",
    "q_generator.load_state_dict(checkpoint)\n",
    "\n",
    "train_gan(\n",
    "    generator=q_generator,\n",
    "    discriminator=discriminator,\n",
    "    dataloader=qgan_dataloader,\n",
    "    log_dir_name=\"q_logs\",\n",
    "    fixed_noise=fixed_noise,\n",
    "    random_fake_data_generator=lambda b_size: random_fake_data_for_qgan(b_size, 4),\n",
    "    num_epochs=1,\n",
    "    # num_epochs=10,\n",
    "    device=\"cpu\",\n",
    ")\n",
    "\n",
    "# Save trained generator model\n",
    "torch.save(q_generator.state_dict(), path / \"resources/q_generator_model_bs64.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "423636ec0b58e28",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 3.3 Performance evaluation "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48d6033c-c920-44f4-b063-07d3183f5257",
   "metadata": {},
   "source": [
    "Finally, we evaluate the performance of the QGAN, as we did for its classical counterpart."
   ]
  },
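  {
   "cell_type": "markdown",
   "id": "sketch-bas-validity",
   "metadata": {},
   "source": [
    "`evaluate_generator` is defined earlier in the notebook. As a hedged illustration of what such a metric can check, the hypothetical helper below tests whether a flattened 2x2 binary image (row-major order assumed) is a valid Bars and Stripes pattern. It counts the uniform all-0 and all-1 images as valid, as in the standard definition of the dataset, and reports the fraction of valid samples. It is a sketch under these assumptions, not the notebook's own implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def is_bars_and_stripes(sample):\n",
    "    # Hypothetical helper: `sample` is a flattened 2x2 binary image in row-major order\n",
    "    img = np.asarray(sample).reshape(2, 2)\n",
    "    rows_uniform = bool(np.all(img[:, 0] == img[:, 1]))  # each row constant: horizontal stripes\n",
    "    cols_uniform = bool(np.all(img[0, :] == img[1, :]))  # each column constant: vertical bars\n",
    "    return rows_uniform or cols_uniform\n",
    "\n",
    "\n",
    "def fraction_valid(samples):\n",
    "    return float(np.mean([is_bars_and_stripes(s) for s in samples]))\n",
    "\n",
    "\n",
    "print(fraction_valid([[1, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1], [1, 0, 0, 1]]))  # 0.75\n",
    "```"
   ]
  },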
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f5046df45d384b6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:09.860719Z",
     "start_time": "2024-02-21T14:02:07.297626Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:51:48.779879Z",
     "iopub.status.busy": "2024-05-07T14:51:48.779417Z",
     "iopub.status.idle": "2024-05-07T14:51:50.187233Z",
     "shell.execute_reply": "2024-05-07T14:51:50.186513Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Quantum-classical hybrid trained generator accuracy: 100.00%\n"
     ]
    }
   ],
   "source": [
    "generator = q_gen\n",
    "checkpoint = torch.load(path / \"resources/q_generator_model_bs64.pth\")\n",
    "generator.load_state_dict(checkpoint)\n",
    "num_samples = 10\n",
    "z = torch.bernoulli(torch.rand(num_samples, 4))\n",
    "gen_data = generator(z)\n",
    "\n",
    "accuracy = evaluate_generator(samples=gen_data)\n",
    "print(f\"Quantum-classical hybrid trained generator accuracy: {accuracy:.2%}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "77cfc50360a0cecc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-21T14:02:12.505574Z",
     "start_time": "2024-02-21T14:02:09.861535Z"
    },
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2024-05-07T14:51:50.191651Z",
     "iopub.status.busy": "2024-05-07T14:51:50.190524Z",
     "iopub.status.idle": "2024-05-07T14:51:51.587581Z",
     "shell.execute_reply": "2024-05-07T14:51:51.586877Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.],\n",
      "        [1., 1., 1., 1.]], grad_fn=<RoundBackward0>)\n"
     ]
    }
   ],
   "source": [
    "z = random_fake_data_for_qgan(10, 4)\n",
    "gen_data = generator(z)\n",
    "print(gen_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c18f8867230ac7d6",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "***\n",
    "Why do you think the accuracy is so high?\n",
    "\n",
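    "Answer $\\rightarrow$ the system has settled into a metastable pathway in which no rule is ever violated: as the output above shows, it produces the same all-ones pattern for every input. To improve this, try longer training, a different set of hyperparameters, a different architecture, etc."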
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
